fix: 修复 remove_task 未释放 KV cache slot 导致第二轮对话死锁

- remove_task() 现在释放 KV cache slot 和 prefix cache 引用
- _refill_active_batch 中 alloc 失败时将剩余 task 推回 waiting_queue
- 主循环增加 try/except 异常兜底,发送 _STOP 给所有 task
- 重构:server.py 全局变量改为 ServerState 类;automodel.py
  使用 Registry 替代裸 dict;合并 TrainContextBuilder 的 with_*
  方法到 build()
This commit is contained in:
ViperEkura 2026-05-08 14:53:04 +08:00
parent ffff05b2c6
commit a6f5ff3b37
8 changed files with 165 additions and 142 deletions

View File

@ -2,7 +2,6 @@
import asyncio import asyncio
import gc import gc
import logging
import threading import threading
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union
@ -12,8 +11,6 @@ import torch.nn as nn
from astrai.inference.scheduler import _STOP, InferenceScheduler from astrai.inference.scheduler import _STOP, InferenceScheduler
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__)
class GenerationRequest: class GenerationRequest:
"""Request parameters for text generation. """Request parameters for text generation.

View File

@ -1,5 +1,6 @@
"""Inference scheduler for single-GPU continuous batching.""" """Inference scheduler for single-GPU continuous batching."""
import logging
import threading import threading
import time import time
import uuid import uuid
@ -12,6 +13,8 @@ from torch import Tensor
from astrai.model.automodel import AutoModel from astrai.model.automodel import AutoModel
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__)
_STOP = object() _STOP = object()
@ -506,9 +509,20 @@ class InferenceScheduler:
task_id: The task to remove. task_id: The task to remove.
""" """
with self._lock: with self._lock:
removed_active = [t for t in self.active_tasks if t.task_id == task_id]
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id] self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id] self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
for task in removed_active:
if task.prefix_len > 0:
prefix = tuple(task.prompt_ids[: task.prefix_len])
self.prefix_cache.release(prefix)
if task.prefix_len < len(task.prompt_ids):
self.prefix_cache.release(tuple(task.prompt_ids))
if task.slot >= 0:
self._free_slot(task.slot)
task.slot = -1
def _remove_finished_tasks(self) -> None: def _remove_finished_tasks(self) -> None:
"""Removes all finished tasks from the active batch. """Removes all finished tasks from the active batch.
@ -553,7 +567,7 @@ class InferenceScheduler:
for _ in range(n): for _ in range(n):
to_add.append(self.waiting_queue.pop(0)) to_add.append(self.waiting_queue.pop(0))
for task in to_add: for i, task in enumerate(to_add):
slot = -1 slot = -1
reused = False reused = False
if task.prefix_len > 0: if task.prefix_len > 0:
@ -564,6 +578,8 @@ class InferenceScheduler:
if slot < 0: if slot < 0:
slot = self._alloc_slot() slot = self._alloc_slot()
if slot < 0: if slot < 0:
with self._lock:
self.waiting_queue[:0] = to_add[i:]
break break
task.slot = slot task.slot = slot
task.status = TaskStatus.RUNNING task.status = TaskStatus.RUNNING
@ -712,32 +728,42 @@ class InferenceScheduler:
Decode processes only the largest position group to ensure all Decode processes only the largest position group to ensure all
batched tasks share the same KV cache write position. batched tasks share the same KV cache write position.
""" """
while self._running: try:
self._remove_finished_tasks() while self._running:
self._refill_active_batch() self._remove_finished_tasks()
self._refill_active_batch()
with self._lock: with self._lock:
if not self.active_tasks and not self.waiting_queue: if not self.active_tasks and not self.waiting_queue:
self._task_event.clear()
self._task_event.wait(timeout=0.01)
continue
tasks = self.active_tasks[:]
to_prefill = [t for t in tasks if t.output_tokens == 0]
if to_prefill:
self._execute_prefill(to_prefill)
pos_groups: Dict[int, List[Task]] = {}
for t in self.active_tasks:
pos_groups.setdefault(t.next_pos, []).append(t)
if pos_groups:
best_pos = max(pos_groups, key=lambda p: len(pos_groups[p]))
self._execute_decode(pos_groups[best_pos], best_pos)
if not self.waiting_queue and len(self.active_tasks) <= 1:
self._task_event.wait(timeout=0.005)
self._task_event.clear() self._task_event.clear()
self._task_event.wait(timeout=0.01) except Exception as e:
continue logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
tasks = self.active_tasks[:] for task in self.active_tasks:
if task.stream_callback:
to_prefill = [t for t in tasks if t.output_tokens == 0] task.stream_callback(_STOP)
if to_prefill: for task in self.waiting_queue:
self._execute_prefill(to_prefill) if task.stream_callback:
task.stream_callback(_STOP)
pos_groups: Dict[int, List[Task]] = {} raise
for t in self.active_tasks:
pos_groups.setdefault(t.next_pos, []).append(t)
if pos_groups:
best_pos = max(pos_groups, key=lambda p: len(pos_groups[p]))
self._execute_decode(pos_groups[best_pos], best_pos)
if not self.waiting_queue and len(self.active_tasks) <= 1:
self._task_event.wait(timeout=0.005)
self._task_event.clear()
def start(self) -> None: def start(self) -> None:
"""Starts the background generation loop thread.""" """Starts the background generation loop thread."""

View File

@ -23,16 +23,30 @@ from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_engine: Optional[InferenceEngine] = None
_model_param: Optional[Any] = None
_project_root = Path(__file__).parent.parent.parent _project_root = Path(__file__).parent.parent.parent
_server_config: Dict[str, Any] = {
"device": "cuda", class ServerState:
"dtype": torch.bfloat16, """Encapsulates all server runtime state.
"param_path": None,
"max_batch_size": 16, Attributes:
} engine: The inference engine instance.
model_param: The loaded model.
config: Server configuration dict.
"""
def __init__(self):
self.engine: Optional[InferenceEngine] = None
self.model_param: Optional[Any] = None
self.config: Dict[str, Any] = {
"device": "cuda",
"dtype": torch.bfloat16,
"param_path": None,
"max_batch_size": 16,
}
_state = ServerState()
def configure_server( def configure_server(
@ -41,28 +55,29 @@ def configure_server(
param_path: Optional[Path] = None, param_path: Optional[Path] = None,
max_batch_size: int = 16, max_batch_size: int = 16,
): ):
_server_config["device"] = device _state.config.update(
_server_config["dtype"] = dtype device=device,
_server_config["param_path"] = param_path dtype=dtype,
_server_config["max_batch_size"] = max_batch_size param_path=param_path,
max_batch_size=max_batch_size,
)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
global _model_param, _engine
try: try:
load_model( load_model(
param_path=_server_config["param_path"], param_path=_state.config["param_path"],
device=_server_config["device"], device=_state.config["device"],
dtype=_server_config["dtype"], dtype=_state.config["dtype"],
max_batch_size=_server_config["max_batch_size"], max_batch_size=_state.config["max_batch_size"],
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to load model: {e}") logger.error(f"Failed to load model: {e}")
raise raise
yield yield
if _engine: if _state.engine:
_engine.shutdown() _state.engine.shutdown()
logger.info("Inference engine shutdown complete") logger.info("Inference engine shutdown complete")
@ -75,25 +90,30 @@ def load_model(
dtype: torch.dtype = torch.bfloat16, dtype: torch.dtype = torch.bfloat16,
max_batch_size: int = 16, max_batch_size: int = 16,
): ):
global _model_param, _engine
if param_path is None: if param_path is None:
param_path = _project_root / "params" param_path = _project_root / "params"
if not param_path.exists(): if not param_path.exists():
raise FileNotFoundError(f"Parameter directory not found: {param_path}") raise FileNotFoundError(f"Parameter directory not found: {param_path}")
tokenizer = AutoTokenizer.from_pretrained(param_path) tokenizer = AutoTokenizer.from_pretrained(param_path)
_model_param = AutoModel.from_pretrained(param_path) _state.model_param = AutoModel.from_pretrained(param_path)
_model_param.to(device=device, dtype=dtype) _state.model_param.to(device=device, dtype=dtype)
logger.info(f"Model loaded on {device} with dtype {dtype}") logger.info(f"Model loaded on {device} with dtype {dtype}")
_engine = InferenceEngine( _state.engine = InferenceEngine(
model=_model_param, model=_state.model_param,
tokenizer=tokenizer, tokenizer=tokenizer,
max_batch_size=max_batch_size, max_batch_size=max_batch_size,
) )
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}") logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
def _get_engine() -> InferenceEngine:
if _state.engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
return _state.engine
class ChatMessage(BaseModel): class ChatMessage(BaseModel):
role: str role: str
content: str content: str
@ -121,30 +141,27 @@ class CompletionResponse(BaseModel):
async def health(): async def health():
return { return {
"status": "ok", "status": "ok",
"model_loaded": _model_param is not None, "model_loaded": _state.model_param is not None,
"engine_ready": _engine is not None, "engine_ready": _state.engine is not None,
} }
@app.get("/stats") @app.get("/stats")
async def get_stats(): async def get_stats():
if _engine is None: return _get_engine().get_stats()
raise HTTPException(status_code=503, detail="Engine not initialized")
return _engine.get_stats()
@app.post("/v1/chat/completions", response_model=CompletionResponse) @app.post("/v1/chat/completions", response_model=CompletionResponse)
async def chat_completion(request: ChatCompletionRequest): async def chat_completion(request: ChatCompletionRequest):
if _engine is None: engine = _get_engine()
raise HTTPException(status_code=503, detail="Engine not initialized")
prompt = _engine.tokenizer.apply_chat_template( prompt = engine.tokenizer.apply_chat_template(
[{"role": m.role, "content": m.content} for m in request.messages], [{"role": m.role, "content": m.content} for m in request.messages],
tokenize=False, tokenize=False,
) )
if request.stream: if request.stream:
agen = _engine.generate_async( agen = engine.generate_async(
prompt=prompt, prompt=prompt,
max_tokens=request.max_tokens, max_tokens=request.max_tokens,
temperature=request.temperature, temperature=request.temperature,
@ -163,7 +180,7 @@ async def chat_completion(request: ChatCompletionRequest):
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
) )
else: else:
result = _engine.generate( result = engine.generate(
prompt=prompt, prompt=prompt,
stream=False, stream=False,
max_tokens=request.max_tokens, max_tokens=request.max_tokens,
@ -198,8 +215,7 @@ async def generate(
max_len: int = 2048, max_len: int = 2048,
stream: bool = False, stream: bool = False,
): ):
if _engine is None: engine = _get_engine()
raise HTTPException(status_code=503, detail="Engine not initialized")
messages = [] messages = []
if history: if history:
@ -209,10 +225,10 @@ async def generate(
messages.append({"role": "assistant", "content": h[1]}) messages.append({"role": "assistant", "content": h[1]})
messages.append({"role": "user", "content": query}) messages.append({"role": "user", "content": query})
prompt = _engine.tokenizer.apply_chat_template(messages, tokenize=False) prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
if stream: if stream:
agen = _engine.generate_async( agen = engine.generate_async(
prompt=prompt, prompt=prompt,
max_tokens=max_len, max_tokens=max_len,
temperature=temperature, temperature=temperature,
@ -226,7 +242,7 @@ async def generate(
return StreamingResponse(text_stream(), media_type="text/plain") return StreamingResponse(text_stream(), media_type="text/plain")
else: else:
result = _engine.generate( result = engine.generate(
prompt=prompt, prompt=prompt,
stream=False, stream=False,
max_tokens=max_len, max_tokens=max_len,

View File

@ -4,12 +4,13 @@ AutoModel base class for model loading and saving.
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Dict, Self, Type, Union from typing import Self, Type, Union
import safetensors.torch as st import safetensors.torch as st
import torch.nn as nn import torch.nn as nn
from astrai.config import ModelConfig from astrai.config import ModelConfig
from astrai.factory import Registry
@contextmanager @contextmanager
@ -44,8 +45,7 @@ class AutoModel(nn.Module):
Provides model loading/saving and generation capabilities. Provides model loading/saving and generation capabilities.
""" """
# Model registry - stored as class attribute _registry = Registry()
_registry: Dict[str, Type["AutoModel"]] = {}
def __init__(self, config: ModelConfig): def __init__(self, config: ModelConfig):
super().__init__() super().__init__()
@ -63,7 +63,7 @@ class AutoModel(nn.Module):
""" """
def decorator(sub_cls: Type["AutoModel"]) -> Type["AutoModel"]: def decorator(sub_cls: Type["AutoModel"]) -> Type["AutoModel"]:
cls._registry[model_type.lower()] = sub_cls cls._registry.register(model_type.lower(), sub_cls)
return sub_cls return sub_cls
return decorator return decorator
@ -72,12 +72,12 @@ class AutoModel(nn.Module):
def get_model_class(cls, model_type: str) -> Type["AutoModel"]: def get_model_class(cls, model_type: str) -> Type["AutoModel"]:
"""Get model class by model_type string.""" """Get model class by model_type string."""
model_type = model_type.lower() model_type = model_type.lower()
if model_type not in cls._registry: if not cls._registry.contains(model_type):
available = list(cls._registry.keys()) available = cls._registry.list_names()
raise ValueError( raise ValueError(
f"Unknown model_type: {model_type}. Available: {available}" f"Unknown model_type: {model_type}. Available: {available}"
) )
return cls._registry[model_type] return cls._registry.get(model_type)
@classmethod @classmethod
def from_pretrained( def from_pretrained(
@ -96,14 +96,8 @@ class AutoModel(nn.Module):
else: else:
raise FileNotFoundError(f"Config file not found: {config_path}") raise FileNotFoundError(f"Config file not found: {config_path}")
# If called from base class, use model_type to determine actual model class model_type = config.model_type or "transformer"
if cls is AutoModel: actual_cls = cls.get_model_class(model_type)
model_type = config.model_type or "transformer"
actual_cls = cls.get_model_class(model_type)
else:
raise ValueError(
f"Cannot call from_pretrained() on subclass {cls.__name__}"
)
with _disable_random_init(enable=disable_random_init): with _disable_random_init(enable=disable_random_init):
model = actual_cls(config) model = actual_cls(config)

View File

@ -34,66 +34,60 @@ class TrainContext:
class TrainContextBuilder: class TrainContextBuilder:
def __init__(self, config: TrainConfig): def __init__(self, config: TrainConfig):
self.config = config self.config = config
self._context = TrainContext( self._checkpoint: Optional[Checkpoint] = None
model=config.model,
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
self._checkpoint = checkpoint
return self
def build(self) -> TrainContext:
context = TrainContext(
model=self.config.model,
world_size=get_world_size(), world_size=get_world_size(),
rank=get_rank(), rank=get_rank(),
) )
device = get_current_device() device = get_current_device()
self._context.model = self._context.model.to(device=device) context.model = context.model.to(device=device)
if self.config.nprocs > 1: if self.config.nprocs > 1 and self.config.parallel_wrapper:
fn = self.config.parallel_wrapper context.model = self.config.parallel_wrapper(context.model)
self._context.model = fn(self._context.model)
self._context.optimizer = self.config.optimizer_fn(self._context.model) if self._checkpoint is not None:
self._context.scheduler = self.config.scheduler_fn(self._context.optimizer) context.epoch = max(self._checkpoint.epoch, self.config.start_epoch)
context.iteration = max(self._checkpoint.iteration, self.config.start_batch)
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self: context.model.load_state_dict(self._checkpoint.state_dict)
if checkpoint is None: context.checkpoint = self._checkpoint
checkpoint = Checkpoint(
state_dict=self._context.model.state_dict(),
)
else: else:
# resume from the assigned checkpoint or assigned iteration context.checkpoint = Checkpoint(
self._context.epoch = max(checkpoint.epoch, self.config.start_epoch) state_dict=context.model.state_dict(),
self._context.iteration = max(checkpoint.iteration, self.config.start_batch) )
self._context.model.load_state_dict(checkpoint.state_dict)
self._context.checkpoint = checkpoint context.optimizer = self.config.optimizer_fn(context.model)
return self context.scheduler = self.config.scheduler_fn(context.optimizer)
def with_dataloader(self) -> Self: cfg = self.config
# fix: change batch level iteration to sample level offset sampler_offset = context.iteration * cfg.batch_size
config = self.config sampler = ResumableDistributedSampler(
sampler_offset = self._context.iteration * config.batch_size data_source=cfg.dataset,
resumeable_sampler = ResumableDistributedSampler( start_epoch=context.epoch,
data_source=config.dataset,
start_epoch=self._context.epoch,
start_iter=sampler_offset, start_iter=sampler_offset,
seed=config.random_seed, seed=cfg.random_seed,
)
context.dataloader = DataLoader(
cfg.dataset,
batch_size=cfg.batch_size,
sampler=sampler,
num_workers=cfg.num_workers,
pin_memory=cfg.pin_memory,
prefetch_factor=cfg.prefetch_factor,
) )
dataloader = DataLoader( context.strategy = StrategyFactory.create(
config.dataset, model=context.model,
batch_size=config.batch_size,
sampler=resumeable_sampler,
num_workers=config.num_workers,
pin_memory=config.pin_memory,
prefetch_factor=config.prefetch_factor,
)
self._context.dataloader = dataloader
return self
def with_strategy(self) -> Self:
self._context.strategy = StrategyFactory.create(
model=self._context.model,
train_type=self.config.strategy, train_type=self.config.strategy,
device=get_current_device(), device=device,
**self.config.extra_kwargs, **self.config.extra_kwargs,
) )
return self
def build(self) -> TrainContext: return context
return self._context

View File

@ -35,11 +35,7 @@ class Trainer:
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext: def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
return ( return (
TrainContextBuilder(self.train_config) TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build()
.with_checkpoint(checkpoint)
.with_dataloader()
.with_strategy()
.build()
) )
def _call_callbacks(self, method_name: str, context: TrainContext): def _call_callbacks(self, method_name: str, context: TrainContext):

View File

@ -53,5 +53,5 @@ def mock_engine():
@pytest.fixture @pytest.fixture
def loaded_model(mock_model_param, monkeypatch): def loaded_model(mock_model_param, monkeypatch):
"""Simulate that the model is loaded.""" """Simulate that the model is loaded."""
monkeypatch.setattr("astrai.inference.server._model_param", mock_model_param) monkeypatch.setattr("astrai.inference.server._state.model_param", mock_model_param)
return mock_model_param return mock_model_param

View File

@ -5,8 +5,8 @@ import pytest
def test_health_no_model(client, monkeypatch): def test_health_no_model(client, monkeypatch):
"""GET /health should return 200 even when model not loaded.""" """GET /health should return 200 even when model not loaded."""
monkeypatch.setattr("astrai.inference.server._model_param", None) monkeypatch.setattr("astrai.inference.server._state.model_param", None)
monkeypatch.setattr("astrai.inference.server._engine", None) monkeypatch.setattr("astrai.inference.server._state.engine", None)
response = client.get("/health") response = client.get("/health")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@ -17,7 +17,7 @@ def test_health_no_model(client, monkeypatch):
def test_health_with_model(client, loaded_model, mock_engine, monkeypatch): def test_health_with_model(client, loaded_model, mock_engine, monkeypatch):
"""GET /health should return 200 when model is loaded.""" """GET /health should return 200 when model is loaded."""
monkeypatch.setattr("astrai.inference.server._engine", mock_engine) monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.get("/health") response = client.get("/health")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@ -28,7 +28,7 @@ def test_health_with_model(client, loaded_model, mock_engine, monkeypatch):
def test_generate_non_stream(client, loaded_model, mock_engine, monkeypatch): def test_generate_non_stream(client, loaded_model, mock_engine, monkeypatch):
"""POST /generate with stream=false should return JSON response.""" """POST /generate with stream=false should return JSON response."""
monkeypatch.setattr("astrai.inference.server._engine", mock_engine) monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post( response = client.post(
"/generate", "/generate",
params={ params={
@ -54,7 +54,7 @@ def test_generate_stream(client, loaded_model, mock_engine, monkeypatch):
yield "chunk2" yield "chunk2"
mock_engine.generate.return_value = stream_gen() mock_engine.generate.return_value = stream_gen()
monkeypatch.setattr("astrai.inference.server._engine", mock_engine) monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post( response = client.post(
"/generate", "/generate",
params={ params={
@ -78,7 +78,7 @@ def test_generate_stream(client, loaded_model, mock_engine, monkeypatch):
def test_chat_completions_non_stream(client, loaded_model, mock_engine, monkeypatch): def test_chat_completions_non_stream(client, loaded_model, mock_engine, monkeypatch):
"""POST /v1/chat/completions with stream=false returns OpenAIstyle JSON.""" """POST /v1/chat/completions with stream=false returns OpenAIstyle JSON."""
mock_engine.generate.return_value = "Assistant reply" mock_engine.generate.return_value = "Assistant reply"
monkeypatch.setattr("astrai.inference.server._engine", mock_engine) monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post( response = client.post(
"/v1/chat/completions", "/v1/chat/completions",
json={ json={
@ -106,7 +106,7 @@ def test_chat_completions_stream(client, loaded_model, mock_engine, monkeypatch)
yield "[DONE]" yield "[DONE]"
mock_engine.generate_async.return_value = async_gen() mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._engine", mock_engine) monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post( response = client.post(
"/v1/chat/completions", "/v1/chat/completions",
json={ json={
@ -132,7 +132,7 @@ def test_chat_completions_stream(client, loaded_model, mock_engine, monkeypatch)
def test_generate_with_history(client, loaded_model, mock_engine, monkeypatch): def test_generate_with_history(client, loaded_model, mock_engine, monkeypatch):
"""POST /generate with history parameter.""" """POST /generate with history parameter."""
monkeypatch.setattr("astrai.inference.server._engine", mock_engine) monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post( response = client.post(
"/generate", "/generate",
params={ params={