fix: 修复 remove_task 未释放 KV cache slot 导致第二轮对话死锁

- remove_task() 现在释放 KV cache slot 和 prefix cache 引用
- _refill_active_batch 中 alloc 失败时将剩余 task 推回 waiting_queue
- 主循环增加 try/except 异常兜底,发送 _STOP 给所有 task
- 重构:server.py 全局变量改为 ServerState 类;automodel.py
  使用 Registry 替代裸 dict;将 TrainContextBuilder 的
  with_dataloader / with_strategy 合并到 build()(保留 with_checkpoint)
This commit is contained in:
ViperEkura 2026-05-08 14:53:04 +08:00
parent ffff05b2c6
commit a6f5ff3b37
8 changed files with 165 additions and 142 deletions

View File

@ -2,7 +2,6 @@
import asyncio
import gc
import logging
import threading
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union
@ -12,8 +11,6 @@ import torch.nn as nn
from astrai.inference.scheduler import _STOP, InferenceScheduler
from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__)
class GenerationRequest:
"""Request parameters for text generation.

View File

@ -1,5 +1,6 @@
"""Inference scheduler for single-GPU continuous batching."""
import logging
import threading
import time
import uuid
@ -12,6 +13,8 @@ from torch import Tensor
from astrai.model.automodel import AutoModel
from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__)
_STOP = object()
@ -506,9 +509,20 @@ class InferenceScheduler:
task_id: The task to remove.
"""
with self._lock:
removed_active = [t for t in self.active_tasks if t.task_id == task_id]
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
for task in removed_active:
if task.prefix_len > 0:
prefix = tuple(task.prompt_ids[: task.prefix_len])
self.prefix_cache.release(prefix)
if task.prefix_len < len(task.prompt_ids):
self.prefix_cache.release(tuple(task.prompt_ids))
if task.slot >= 0:
self._free_slot(task.slot)
task.slot = -1
def _remove_finished_tasks(self) -> None:
"""Removes all finished tasks from the active batch.
@ -553,7 +567,7 @@ class InferenceScheduler:
for _ in range(n):
to_add.append(self.waiting_queue.pop(0))
for task in to_add:
for i, task in enumerate(to_add):
slot = -1
reused = False
if task.prefix_len > 0:
@ -564,6 +578,8 @@ class InferenceScheduler:
if slot < 0:
slot = self._alloc_slot()
if slot < 0:
with self._lock:
self.waiting_queue[:0] = to_add[i:]
break
task.slot = slot
task.status = TaskStatus.RUNNING
@ -712,6 +728,7 @@ class InferenceScheduler:
Decode processes only the largest position group to ensure all
batched tasks share the same KV cache write position.
"""
try:
while self._running:
self._remove_finished_tasks()
self._refill_active_batch()
@ -738,6 +755,15 @@ class InferenceScheduler:
if not self.waiting_queue and len(self.active_tasks) <= 1:
self._task_event.wait(timeout=0.005)
self._task_event.clear()
except Exception as e:
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
for task in self.active_tasks:
if task.stream_callback:
task.stream_callback(_STOP)
for task in self.waiting_queue:
if task.stream_callback:
task.stream_callback(_STOP)
raise
def start(self) -> None:
"""Starts the background generation loop thread."""

View File

@ -23,11 +23,22 @@ from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__)
_engine: Optional[InferenceEngine] = None
_model_param: Optional[Any] = None
_project_root = Path(__file__).parent.parent.parent
_server_config: Dict[str, Any] = {
class ServerState:
"""Encapsulates all server runtime state.
Attributes:
engine: The inference engine instance.
model_param: The loaded model.
config: Server configuration dict.
"""
def __init__(self):
self.engine: Optional[InferenceEngine] = None
self.model_param: Optional[Any] = None
self.config: Dict[str, Any] = {
"device": "cuda",
"dtype": torch.bfloat16,
"param_path": None,
@ -35,34 +46,38 @@ _server_config: Dict[str, Any] = {
}
_state = ServerState()
def configure_server(
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
param_path: Optional[Path] = None,
max_batch_size: int = 16,
):
_server_config["device"] = device
_server_config["dtype"] = dtype
_server_config["param_path"] = param_path
_server_config["max_batch_size"] = max_batch_size
_state.config.update(
device=device,
dtype=dtype,
param_path=param_path,
max_batch_size=max_batch_size,
)
@asynccontextmanager
async def lifespan(app: FastAPI):
global _model_param, _engine
try:
load_model(
param_path=_server_config["param_path"],
device=_server_config["device"],
dtype=_server_config["dtype"],
max_batch_size=_server_config["max_batch_size"],
param_path=_state.config["param_path"],
device=_state.config["device"],
dtype=_state.config["dtype"],
max_batch_size=_state.config["max_batch_size"],
)
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise
yield
if _engine:
_engine.shutdown()
if _state.engine:
_state.engine.shutdown()
logger.info("Inference engine shutdown complete")
@ -75,25 +90,30 @@ def load_model(
dtype: torch.dtype = torch.bfloat16,
max_batch_size: int = 16,
):
global _model_param, _engine
if param_path is None:
param_path = _project_root / "params"
if not param_path.exists():
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
tokenizer = AutoTokenizer.from_pretrained(param_path)
_model_param = AutoModel.from_pretrained(param_path)
_model_param.to(device=device, dtype=dtype)
_state.model_param = AutoModel.from_pretrained(param_path)
_state.model_param.to(device=device, dtype=dtype)
logger.info(f"Model loaded on {device} with dtype {dtype}")
_engine = InferenceEngine(
model=_model_param,
_state.engine = InferenceEngine(
model=_state.model_param,
tokenizer=tokenizer,
max_batch_size=max_batch_size,
)
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
def _get_engine() -> InferenceEngine:
if _state.engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
return _state.engine
class ChatMessage(BaseModel):
role: str
content: str
@ -121,30 +141,27 @@ class CompletionResponse(BaseModel):
async def health():
return {
"status": "ok",
"model_loaded": _model_param is not None,
"engine_ready": _engine is not None,
"model_loaded": _state.model_param is not None,
"engine_ready": _state.engine is not None,
}
@app.get("/stats")
async def get_stats():
if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
return _engine.get_stats()
return _get_engine().get_stats()
@app.post("/v1/chat/completions", response_model=CompletionResponse)
async def chat_completion(request: ChatCompletionRequest):
if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
engine = _get_engine()
prompt = _engine.tokenizer.apply_chat_template(
prompt = engine.tokenizer.apply_chat_template(
[{"role": m.role, "content": m.content} for m in request.messages],
tokenize=False,
)
if request.stream:
agen = _engine.generate_async(
agen = engine.generate_async(
prompt=prompt,
max_tokens=request.max_tokens,
temperature=request.temperature,
@ -163,7 +180,7 @@ async def chat_completion(request: ChatCompletionRequest):
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)
else:
result = _engine.generate(
result = engine.generate(
prompt=prompt,
stream=False,
max_tokens=request.max_tokens,
@ -198,8 +215,7 @@ async def generate(
max_len: int = 2048,
stream: bool = False,
):
if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
engine = _get_engine()
messages = []
if history:
@ -209,10 +225,10 @@ async def generate(
messages.append({"role": "assistant", "content": h[1]})
messages.append({"role": "user", "content": query})
prompt = _engine.tokenizer.apply_chat_template(messages, tokenize=False)
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
if stream:
agen = _engine.generate_async(
agen = engine.generate_async(
prompt=prompt,
max_tokens=max_len,
temperature=temperature,
@ -226,7 +242,7 @@ async def generate(
return StreamingResponse(text_stream(), media_type="text/plain")
else:
result = _engine.generate(
result = engine.generate(
prompt=prompt,
stream=False,
max_tokens=max_len,

View File

@ -4,12 +4,13 @@ AutoModel base class for model loading and saving.
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Self, Type, Union
from typing import Self, Type, Union
import safetensors.torch as st
import torch.nn as nn
from astrai.config import ModelConfig
from astrai.factory import Registry
@contextmanager
@ -44,8 +45,7 @@ class AutoModel(nn.Module):
Provides model loading/saving and generation capabilities.
"""
# Model registry - stored as class attribute
_registry: Dict[str, Type["AutoModel"]] = {}
_registry = Registry()
def __init__(self, config: ModelConfig):
super().__init__()
@ -63,7 +63,7 @@ class AutoModel(nn.Module):
"""
def decorator(sub_cls: Type["AutoModel"]) -> Type["AutoModel"]:
cls._registry[model_type.lower()] = sub_cls
cls._registry.register(model_type.lower(), sub_cls)
return sub_cls
return decorator
@ -72,12 +72,12 @@ class AutoModel(nn.Module):
def get_model_class(cls, model_type: str) -> Type["AutoModel"]:
"""Get model class by model_type string."""
model_type = model_type.lower()
if model_type not in cls._registry:
available = list(cls._registry.keys())
if not cls._registry.contains(model_type):
available = cls._registry.list_names()
raise ValueError(
f"Unknown model_type: {model_type}. Available: {available}"
)
return cls._registry[model_type]
return cls._registry.get(model_type)
@classmethod
def from_pretrained(
@ -96,14 +96,8 @@ class AutoModel(nn.Module):
else:
raise FileNotFoundError(f"Config file not found: {config_path}")
# If called from base class, use model_type to determine actual model class
if cls is AutoModel:
model_type = config.model_type or "transformer"
actual_cls = cls.get_model_class(model_type)
else:
raise ValueError(
f"Cannot call from_pretrained() on subclass {cls.__name__}"
)
with _disable_random_init(enable=disable_random_init):
model = actual_cls(config)

View File

@ -34,66 +34,60 @@ class TrainContext:
class TrainContextBuilder:
def __init__(self, config: TrainConfig):
self.config = config
self._context = TrainContext(
model=config.model,
self._checkpoint: Optional[Checkpoint] = None
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
self._checkpoint = checkpoint
return self
def build(self) -> TrainContext:
context = TrainContext(
model=self.config.model,
world_size=get_world_size(),
rank=get_rank(),
)
device = get_current_device()
self._context.model = self._context.model.to(device=device)
context.model = context.model.to(device=device)
if self.config.nprocs > 1:
fn = self.config.parallel_wrapper
self._context.model = fn(self._context.model)
if self.config.nprocs > 1 and self.config.parallel_wrapper:
context.model = self.config.parallel_wrapper(context.model)
self._context.optimizer = self.config.optimizer_fn(self._context.model)
self._context.scheduler = self.config.scheduler_fn(self._context.optimizer)
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
if checkpoint is None:
checkpoint = Checkpoint(
state_dict=self._context.model.state_dict(),
)
if self._checkpoint is not None:
context.epoch = max(self._checkpoint.epoch, self.config.start_epoch)
context.iteration = max(self._checkpoint.iteration, self.config.start_batch)
context.model.load_state_dict(self._checkpoint.state_dict)
context.checkpoint = self._checkpoint
else:
# resume from the assigned checkpoint or assigned iteration
self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
self._context.model.load_state_dict(checkpoint.state_dict)
context.checkpoint = Checkpoint(
state_dict=context.model.state_dict(),
)
self._context.checkpoint = checkpoint
return self
context.optimizer = self.config.optimizer_fn(context.model)
context.scheduler = self.config.scheduler_fn(context.optimizer)
def with_dataloader(self) -> Self:
# fix: change batch level iteration to sample level offset
config = self.config
sampler_offset = self._context.iteration * config.batch_size
resumeable_sampler = ResumableDistributedSampler(
data_source=config.dataset,
start_epoch=self._context.epoch,
cfg = self.config
sampler_offset = context.iteration * cfg.batch_size
sampler = ResumableDistributedSampler(
data_source=cfg.dataset,
start_epoch=context.epoch,
start_iter=sampler_offset,
seed=config.random_seed,
seed=cfg.random_seed,
)
context.dataloader = DataLoader(
cfg.dataset,
batch_size=cfg.batch_size,
sampler=sampler,
num_workers=cfg.num_workers,
pin_memory=cfg.pin_memory,
prefetch_factor=cfg.prefetch_factor,
)
dataloader = DataLoader(
config.dataset,
batch_size=config.batch_size,
sampler=resumeable_sampler,
num_workers=config.num_workers,
pin_memory=config.pin_memory,
prefetch_factor=config.prefetch_factor,
)
self._context.dataloader = dataloader
return self
def with_strategy(self) -> Self:
self._context.strategy = StrategyFactory.create(
model=self._context.model,
context.strategy = StrategyFactory.create(
model=context.model,
train_type=self.config.strategy,
device=get_current_device(),
device=device,
**self.config.extra_kwargs,
)
return self
def build(self) -> TrainContext:
return self._context
return context

View File

@ -35,11 +35,7 @@ class Trainer:
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
return (
TrainContextBuilder(self.train_config)
.with_checkpoint(checkpoint)
.with_dataloader()
.with_strategy()
.build()
TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build()
)
def _call_callbacks(self, method_name: str, context: TrainContext):

View File

@ -53,5 +53,5 @@ def mock_engine():
@pytest.fixture
def loaded_model(mock_model_param, monkeypatch):
"""Simulate that the model is loaded."""
monkeypatch.setattr("astrai.inference.server._model_param", mock_model_param)
monkeypatch.setattr("astrai.inference.server._state.model_param", mock_model_param)
return mock_model_param

View File

@ -5,8 +5,8 @@ import pytest
def test_health_no_model(client, monkeypatch):
"""GET /health should return 200 even when model not loaded."""
monkeypatch.setattr("astrai.inference.server._model_param", None)
monkeypatch.setattr("astrai.inference.server._engine", None)
monkeypatch.setattr("astrai.inference.server._state.model_param", None)
monkeypatch.setattr("astrai.inference.server._state.engine", None)
response = client.get("/health")
assert response.status_code == 200
data = response.json()
@ -17,7 +17,7 @@ def test_health_no_model(client, monkeypatch):
def test_health_with_model(client, loaded_model, mock_engine, monkeypatch):
"""GET /health should return 200 when model is loaded."""
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.get("/health")
assert response.status_code == 200
data = response.json()
@ -28,7 +28,7 @@ def test_health_with_model(client, loaded_model, mock_engine, monkeypatch):
def test_generate_non_stream(client, loaded_model, mock_engine, monkeypatch):
"""POST /generate with stream=false should return JSON response."""
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post(
"/generate",
params={
@ -54,7 +54,7 @@ def test_generate_stream(client, loaded_model, mock_engine, monkeypatch):
yield "chunk2"
mock_engine.generate.return_value = stream_gen()
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post(
"/generate",
params={
@ -78,7 +78,7 @@ def test_generate_stream(client, loaded_model, mock_engine, monkeypatch):
def test_chat_completions_non_stream(client, loaded_model, mock_engine, monkeypatch):
"""POST /v1/chat/completions with stream=false returns OpenAI-style JSON."""
mock_engine.generate.return_value = "Assistant reply"
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post(
"/v1/chat/completions",
json={
@ -106,7 +106,7 @@ def test_chat_completions_stream(client, loaded_model, mock_engine, monkeypatch)
yield "[DONE]"
mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post(
"/v1/chat/completions",
json={
@ -132,7 +132,7 @@ def test_chat_completions_stream(client, loaded_model, mock_engine, monkeypatch)
def test_generate_with_history(client, loaded_model, mock_engine, monkeypatch):
"""POST /generate with history parameter."""
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post(
"/generate",
params={