529 lines
16 KiB
Python
529 lines
16 KiB
Python
"""
OpenAI / Anthropic-compatible chat completion server backed by continuous-batching inference.
"""
|
|
|
|
import json
|
|
import logging
|
|
import time
|
|
import uuid
|
|
from contextlib import asynccontextmanager
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, Union
|
|
|
|
import torch
|
|
import uvicorn
|
|
from fastapi import FastAPI, HTTPException
|
|
from fastapi.responses import PlainTextResponse, StreamingResponse
|
|
from pydantic import BaseModel, Field
|
|
|
|
from astrai.inference.engine import InferenceEngine
|
|
from astrai.model import AutoModel
|
|
from astrai.tokenize import AutoTokenizer
|
|
|
|
# Logger named after this module.
logger = logging.getLogger(__name__)

# Directory three levels above this file; load_model() looks for the default
# "params" directory underneath it.
_project_root = Path(__file__).parent.parent.parent
|
|
|
|
|
|
class ServerState:
    """Mutable, process-wide server state: the engine handle and its launch settings."""

    def __init__(self):
        # Set by load_model(); stays None until a model has been loaded.
        self.engine: Optional[InferenceEngine] = None
        # Launch settings. Every key that configure_server() writes has a
        # default here, so reading the config before configure_server() has
        # been called never raises KeyError.
        self.config: Dict[str, Any] = {
            "device": "cuda",
            "dtype": torch.bfloat16,
            "param_path": None,
            "max_batch_size": 16,
            "max_queue_size": 64,
            "request_timeout": 60.0,
        }


# Single shared instance used by the endpoints and lifecycle hooks of this module.
_state = ServerState()
|
|
|
|
|
|
# One turn of an OpenAI-style conversation: a role name plus plain-text content.
# (Documented with comments rather than a docstring so the generated schema
# description is unchanged.)
class ChatMessage(BaseModel):
    role: str  # passed through unchanged to the tokenizer's chat template
    content: str  # plain string only
|
|
|
|
|
|
class ChatCompletionRequest(BaseModel):
    """OpenAI Chat Completion API request body."""

    # Fields consumed by chat_completion() in this module.
    model: str = "astrai"
    messages: List[ChatMessage]
    temperature: Optional[float] = Field(1.0, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(1.0, ge=0.0, le=1.0)
    top_k: Optional[int] = Field(50, ge=1)
    stream: Optional[bool] = False

    # Validated here, but chat_completion() in this module does not read it.
    stop: Optional[Union[str, List[str]]] = None

    # Consumed by chat_completion() as the generation length limit.
    max_tokens: Optional[int] = Field(2048, ge=1)

    # Validated here, but chat_completion() in this module does not read these.
    n: Optional[int] = Field(1, ge=1)
    presence_penalty: Optional[float] = Field(0.0, ge=-2.0, le=2.0)
    frequency_penalty: Optional[float] = Field(0.0, ge=-2.0, le=2.0)
    logit_bias: Optional[Dict[int, float]] = None
    user: Optional[str] = None
|
|
|
|
|
|
# One turn of an Anthropic-style conversation. ``content`` is either a plain
# string or a list of content-block dicts; _extract_text_content() reduces it
# to text before templating.
class AnthropicMessage(BaseModel):
    role: str
    content: Union[str, List[Dict[str, Any]]]
|
|
|
|
|
|
class MessagesRequest(BaseModel):
    """Anthropic Messages API request body."""

    model: str = "astrai"
    # Generation length limit forwarded to the engine.
    max_tokens: int = Field(1024, ge=1)
    messages: List[AnthropicMessage]
    # Optional system prompt; create_message() prepends it as a "system" turn.
    system: Optional[str] = None

    # Sampling controls forwarded to the engine.
    temperature: Optional[float] = Field(1.0, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(1.0, ge=0.0, le=1.0)
    top_k: Optional[int] = Field(50, ge=1)

    stream: Optional[bool] = False
    # Strings that end generation; applied server-side by create_message().
    stop_sequences: Optional[List[str]] = None
|
|
|
|
|
|
def configure_server(
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16,
    param_path: Optional[Path] = None,
    max_batch_size: int = 16,
    max_queue_size: int = 64,
    request_timeout: float = 60.0,
):
    """Store the launch settings in the shared server state.

    The values are written into ``_state.config``; ``lifespan()`` reads
    ``param_path``, ``device``, ``dtype`` and ``max_batch_size`` from there
    when the application starts.

    Args:
        device: Device string the model is moved to.
        dtype: Torch dtype the model is cast to.
        param_path: Directory holding model/tokenizer parameters, or ``None``
            to let ``load_model()`` pick its default.
        max_batch_size: Batch-size limit handed to the inference engine.
        max_queue_size: Stored under ``"max_queue_size"``. Previously
            hard-coded to 64; ``lifespan()`` in this module does not read it.
        request_timeout: Stored under ``"request_timeout"``. Previously
            hard-coded to 60.0; ``lifespan()`` in this module does not read it.
    """
    _state.config.update(
        device=device,
        dtype=dtype,
        param_path=param_path,
        max_batch_size=max_batch_size,
        max_queue_size=max_queue_size,
        request_timeout=request_timeout,
    )
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model before serving and shut the engine down afterwards.

    Startup calls ``load_model()`` with the values stored in
    ``_state.config``; a failure is logged with its traceback and re-raised
    so the application does not start without a model.

    The shutdown step runs in a ``finally`` block so the engine is also
    released when an exception is thrown into the context at the ``yield``.
    """
    try:
        load_model(
            param_path=_state.config["param_path"],
            device=_state.config["device"],
            dtype=_state.config["dtype"],
            max_batch_size=_state.config["max_batch_size"],
        )
    except Exception as e:
        # logger.exception keeps the ERROR level and message, and adds the traceback.
        logger.exception("Failed to load model: %s", e)
        raise

    try:
        yield
    finally:
        engine = _state.engine
        if engine is not None:
            engine.shutdown()
            logger.info("Inference engine shutdown complete")


app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
|
|
|
|
|
|
def load_model(
    param_path: Optional[Path] = None,
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16,
    max_batch_size: int = 16,
):
    """Load tokenizer and model, then install a fresh engine in ``_state``.

    Args:
        param_path: Directory with the pretrained parameters. ``None`` selects
            ``_project_root / "params"``. Strings and other path-like values
            are accepted and converted to ``Path`` (previously a ``str``
            failed with ``AttributeError`` on ``.exists()``).
        device: Device the model is moved to.
        dtype: Dtype the model is cast to.
        max_batch_size: Batch-size limit passed to ``InferenceEngine``.

    Raises:
        FileNotFoundError: If the resolved ``param_path`` does not exist.
    """
    if param_path is None:
        param_path = _project_root / "params"
    else:
        param_path = Path(param_path)
    if not param_path.exists():
        raise FileNotFoundError(f"Parameter directory not found: {param_path}")

    tokenizer = AutoTokenizer.from_pretrained(param_path)
    model = AutoModel.from_pretrained(param_path)
    model.to(device=device, dtype=dtype)
    logger.info("Model loaded on %s with dtype %s", device, dtype)

    _state.engine = InferenceEngine(
        model=model,
        tokenizer=tokenizer,
        max_batch_size=max_batch_size,
    )
    logger.info("Inference engine initialized with max_batch_size=%s", max_batch_size)
|
|
|
|
|
|
def _get_engine() -> InferenceEngine:
    """Return the shared engine, or answer HTTP 503 when none is installed yet."""
    engine = _state.engine
    if engine is not None:
        return engine
    raise HTTPException(status_code=503, detail="Engine not initialized")
|
|
|
|
|
|
def _make_chunk(
|
|
delta: Dict[str, str],
|
|
finish_reason: Optional[str] = None,
|
|
*,
|
|
resp_id: str,
|
|
created: int,
|
|
model: str,
|
|
index: int = 0,
|
|
) -> str:
|
|
"""Build a single SSE ``data:`` chunk matching OpenAI streaming format."""
|
|
data = {
|
|
"id": resp_id,
|
|
"object": "chat.completion.chunk",
|
|
"created": created,
|
|
"model": model,
|
|
"choices": [
|
|
{
|
|
"index": index,
|
|
"delta": delta,
|
|
"finish_reason": finish_reason,
|
|
}
|
|
],
|
|
}
|
|
return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
|
|
|
|
|
|
# Liveness probe: always reports "ok" and tells whether an engine is installed.
@app.get("/health")
async def health():
    engine_ready = _state.engine is not None
    return {"status": "ok", "model_loaded": engine_ready}
|
|
|
|
|
|
# Returns the engine's statistics as-is; answers 503 (via _get_engine) when no
# engine is installed.
@app.get("/stats")
async def get_stats():
    engine = _get_engine()
    return engine.get_stats()
|
|
|
|
|
|
# Renders the engine statistics as plain-text "# HELP" / "# TYPE" / sample
# lines, one metric family after another, terminated by a newline.
@app.get("/metrics")
async def metrics():
    stats = _get_engine().get_stats()
    out: List[str] = []

    # Request counters, split by outcome.
    out += [
        "# HELP astrai_requests_total Total requests received",
        "# TYPE astrai_requests_total counter",
        f'astrai_requests_total{{status="accepted"}} {stats["total_requests"]}',
        f'astrai_requests_total{{status="rejected"}} {stats["total_rejected"]}',
        f'astrai_requests_total{{status="timeout"}} {stats["total_timeouts"]}',
    ]

    # Token throughput.
    out += [
        "# HELP astrai_tokens_generated Total generated tokens",
        "# TYPE astrai_tokens_generated counter",
        f"astrai_tokens_generated {stats['total_tokens']}",
    ]

    # Scheduler occupancy.
    out += [
        "# HELP astrai_active_tasks Currently active tasks",
        "# TYPE astrai_active_tasks gauge",
        f"astrai_active_tasks {stats['active_tasks']}",
        "# HELP astrai_queue_depth Waiting queue depth",
        "# TYPE astrai_queue_depth gauge",
        f"astrai_queue_depth {stats['waiting_queue']}",
    ]

    # Latency quantiles, formatted with three decimals.
    out += [
        "# HELP astrai_request_latency_seconds Request latency quantiles",
        "# TYPE astrai_request_latency_seconds gauge",
        f'astrai_request_latency_seconds{{quantile="0.5"}} {stats["latency_p50"]:.3f}',
        f'astrai_request_latency_seconds{{quantile="0.95"}} {stats["latency_p95"]:.3f}',
        f'astrai_request_latency_seconds{{quantile="0.99"}} {stats["latency_p99"]:.3f}',
    ]

    # Prefix-cache effectiveness.
    out += [
        "# HELP astrai_cache_hit_rate Prefix cache hit ratio",
        "# TYPE astrai_cache_hit_rate gauge",
        f"astrai_cache_hit_rate {stats['cache_hit_rate']:.3f}",
        "# HELP astrai_cache_lookups_total Prefix cache page lookups",
        "# TYPE astrai_cache_lookups_total counter",
        f'astrai_cache_lookups_total{{result="hit"}} {stats["cache_hits"]}',
        f'astrai_cache_lookups_total{{result="miss"}} {stats["cache_misses"]}',
    ]

    body = "\n".join(out) + "\n"
    return PlainTextResponse(body)
|
|
|
|
|
|
@app.post("/v1/chat/completions")
async def chat_completion(request: ChatCompletionRequest):
    """OpenAI-compatible chat completion endpoint (streaming + non-streaming)."""
    # Only messages, model, stream, max_tokens, temperature, top_p and top_k
    # are used here; stop, n, the penalties, logit_bias and user are accepted
    # by the request model but not forwarded to the engine.
    #
    # Raises HTTPException(503) when no engine is installed or when the engine
    # refuses the request with a RuntimeError.
    engine = _get_engine()
    resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
    created = int(time.time())
    model = request.model

    prompt = engine.tokenizer.apply_chat_template(
        [{"role": m.role, "content": m.content} for m in request.messages],
        tokenize=False,
    )
    prompt_tokens = len(engine.tokenizer.encode(prompt))

    # Both response modes start generation identically, so do it once.
    try:
        agen = engine.generate_async(
            prompt=prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
        )
    except RuntimeError as e:
        raise HTTPException(status_code=503, detail=str(e)) from e

    def usage_for(completion_tokens: int) -> Dict[str, int]:
        # One generated item is counted as one completion token.
        return {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }

    if request.stream:

        async def event_stream():
            yield _make_chunk(
                {"role": "assistant"},
                finish_reason=None,
                resp_id=resp_id,
                created=created,
                model=model,
            )

            completion_tokens = 0
            async for token in agen:
                yield _make_chunk(
                    {"content": token},
                    finish_reason=None,
                    resp_id=resp_id,
                    created=created,
                    model=model,
                )
                completion_tokens += 1

            yield _make_chunk(
                {},
                finish_reason="stop",
                resp_id=resp_id,
                created=created,
                model=model,
            )

            # The usage event is a regular chunk with an empty ``choices``
            # list, so clients that parse every event as a
            # chat.completion.chunk do not fail on it. The flat token
            # counters are kept at top level for consumers of the earlier
            # bare-usage payload.
            usage = usage_for(completion_tokens)
            usage_chunk = {
                "id": resp_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": model,
                "choices": [],
                "usage": usage,
                **usage,
            }
            yield f"data: {json.dumps(usage_chunk, ensure_ascii=False)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(
            event_stream(),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
        )

    chunks: List[str] = [token async for token in agen]
    content = "".join(chunks)

    return {
        "id": resp_id,
        "object": "chat.completion",
        "created": created,
        "model": model,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": content},
                "finish_reason": "stop",
            }
        ],
        "usage": usage_for(len(chunks)),
    }
|
|
|
|
|
|
def _make_anthropic_sse(event: str, data: Dict[str, Any]) -> str:
|
|
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
|
|
|
|
|
def _check_stop_sequence(text: str, stop_sequences: List[str]) -> Optional[str]:
|
|
for seq in stop_sequences:
|
|
if seq and seq in text:
|
|
return seq
|
|
return None
|
|
|
|
|
|
def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str:
|
|
if isinstance(content, str):
|
|
return content
|
|
if isinstance(content, list):
|
|
for block in content:
|
|
if isinstance(block, dict) and block.get("type") == "text":
|
|
return block.get("text", "")
|
|
return ""
|
|
|
|
|
|
def _build_anthropic_messages(
    messages: List[AnthropicMessage], system: Optional[str]
) -> List[Dict[str, str]]:
    """Convert an Anthropic request into role/content dicts for the chat template.

    A non-empty ``system`` prompt becomes a leading ``"system"`` turn. Each
    message is reduced to text with ``_extract_text_content()``; messages that
    yield no text are left out.
    """
    converted: List[Dict[str, str]] = (
        [{"role": "system", "content": system}] if system else []
    )
    role_and_text = (
        (message.role, _extract_text_content(message.content)) for message in messages
    )
    converted.extend(
        {"role": role, "content": text} for role, text in role_and_text if text
    )
    return converted
|
|
|
|
|
|
class _StopSequenceFilter:
    """Incrementally strip stop sequences from a stream of generated text.

    Text is fed in piece by piece; ``feed()`` returns only the part that is
    safe to send to the client. A trailing fragment that could still grow
    into a stop sequence is held back, so no part of a stop sequence is ever
    emitted and nothing already emitted is ever repeated.
    """

    def __init__(self, stop_sequences: List[str]):
        # Empty strings would match everywhere; ignore them.
        self._stops = [seq for seq in stop_sequences if seq]
        self._text = ""
        # Number of leading characters of ``_text`` already handed out.
        self._emitted = 0
        # The stop sequence that ended the text, once one has been found.
        self.matched: Optional[str] = None

    def feed(self, token: str) -> str:
        """Append ``token`` and return the newly releasable text (may be empty)."""
        if self.matched is not None:
            return ""
        self._text += token

        cut = self._find_stop()
        if cut is None:
            cut = len(self._text) - self._pending_prefix_len()
        # Never move backwards over text that has already been released.
        cut = max(cut, self._emitted)

        piece = self._text[self._emitted:cut]
        self._emitted = cut
        return piece

    def flush(self) -> str:
        """Return held-back text once generation ended without a stop match."""
        if self.matched is not None:
            return ""
        piece = self._text[self._emitted:]
        self._emitted = len(self._text)
        return piece

    def _find_stop(self) -> Optional[int]:
        """Record the earliest stop sequence in the unreleased text; return its start."""
        best_pos: Optional[int] = None
        for seq in self._stops:
            # Anything before ``_emitted`` was released only after it was
            # known not to start a stop sequence, so the search begins there.
            pos = self._text.find(seq, self._emitted)
            if pos != -1 and (best_pos is None or pos < best_pos):
                best_pos = pos
                self.matched = seq
        return best_pos

    def _pending_prefix_len(self) -> int:
        """Length of the longest unreleased tail that is a proper prefix of a stop sequence."""
        tail = self._text[self._emitted:]
        longest = 0
        for seq in self._stops:
            upper = min(len(seq) - 1, len(tail))
            for size in range(upper, longest, -1):
                if tail.endswith(seq[:size]):
                    longest = size
                    break
        return longest


async def _close_generation(agen: Any) -> None:
    """Close an abandoned token stream promptly if it supports ``aclose()``."""
    aclose = getattr(agen, "aclose", None)
    if aclose is not None:
        await aclose()


@app.post("/v1/messages")
async def create_message(request: MessagesRequest):
    """Anthropic-compatible Messages API endpoint (streaming + non-streaming)."""
    # Stop sequences are applied here on the generated text:
    #   * the text is cut at the FIRST occurrence of the earliest stop sequence;
    #   * while streaming, only text not sent before is emitted, and a tail
    #     that may still become a stop sequence is held back until decided.
    #
    # Raises HTTPException(503) when no engine is installed or when the engine
    # refuses the request with a RuntimeError (same mapping as the
    # /v1/chat/completions endpoint).
    engine = _get_engine()
    resp_id = f"msg_{uuid.uuid4().hex[:24]}"
    model = request.model

    chat_messages = _build_anthropic_messages(request.messages, request.system)
    prompt = engine.tokenizer.apply_chat_template(chat_messages, tokenize=False)
    prompt_tokens = len(engine.tokenizer.encode(prompt))

    stop_filter = _StopSequenceFilter(request.stop_sequences or [])

    try:
        agen = engine.generate_async(
            prompt=prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
        )
    except RuntimeError as e:
        raise HTTPException(status_code=503, detail=str(e)) from e

    if request.stream:

        def text_delta(text: str) -> str:
            return _make_anthropic_sse(
                "content_block_delta",
                {
                    "type": "content_block_delta",
                    "index": 0,
                    "delta": {"type": "text_delta", "text": text},
                },
            )

        async def event_stream():
            yield _make_anthropic_sse(
                "message_start",
                {
                    "type": "message_start",
                    "message": {
                        "id": resp_id,
                        "type": "message",
                        "role": "assistant",
                        "model": model,
                        "content": [],
                        "usage": {"input_tokens": prompt_tokens},
                    },
                },
            )

            yield _make_anthropic_sse(
                "content_block_start",
                {
                    "type": "content_block_start",
                    "index": 0,
                    "content_block": {"type": "text", "text": ""},
                },
            )

            # One generated item is counted as one output token, including
            # the item that completed a stop sequence.
            completion_tokens = 0
            async for token in agen:
                completion_tokens += 1
                piece = stop_filter.feed(token)
                if piece:
                    yield text_delta(piece)
                if stop_filter.matched is not None:
                    break

            stopped_seq = stop_filter.matched
            if stopped_seq is not None:
                # The stream was abandoned early; release it right away.
                await _close_generation(agen)
            else:
                held_back = stop_filter.flush()
                if held_back:
                    yield text_delta(held_back)

            yield _make_anthropic_sse(
                "content_block_stop",
                {"type": "content_block_stop", "index": 0},
            )

            stop_reason = "stop_sequence" if stopped_seq else "end_turn"
            yield _make_anthropic_sse(
                "message_delta",
                {
                    "type": "message_delta",
                    "delta": {"stop_reason": stop_reason, "stop_sequence": stopped_seq},
                    "usage": {"output_tokens": completion_tokens},
                },
            )

            yield _make_anthropic_sse(
                "message_stop",
                {"type": "message_stop"},
            )

        return StreamingResponse(
            event_stream(),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
        )

    completion_tokens = 0
    pieces: List[str] = []
    async for token in agen:
        completion_tokens += 1
        pieces.append(stop_filter.feed(token))
        if stop_filter.matched is not None:
            break

    stopped_seq = stop_filter.matched
    if stopped_seq is not None:
        await _close_generation(agen)
    else:
        pieces.append(stop_filter.flush())
    content = "".join(pieces)

    return {
        "id": resp_id,
        "type": "message",
        "role": "assistant",
        "model": model,
        "content": [{"type": "text", "text": content}],
        "stop_reason": "stop_sequence" if stopped_seq else "end_turn",
        "stop_sequence": stopped_seq,
        "usage": {
            "input_tokens": prompt_tokens,
            "output_tokens": completion_tokens,
        },
    }
|
|
|
|
|
|
def run_server(
    host: str = "0.0.0.0",
    port: int = 8000,
    reload: bool = False,
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16,
    param_path: Optional[Path] = None,
    max_batch_size: int = 16,
):
    """Record the launch settings, then serve the application with uvicorn.

    Args:
        host: Interface to bind.
        port: TCP port to listen on.
        reload: Passed through to ``uvicorn.run``.
        device: Forwarded to ``configure_server()``.
        dtype: Forwarded to ``configure_server()``.
        param_path: Forwarded to ``configure_server()``.
        max_batch_size: Forwarded to ``configure_server()``.
    """
    engine_settings = {
        "device": device,
        "dtype": dtype,
        "param_path": param_path,
        "max_batch_size": max_batch_size,
    }
    configure_server(**engine_settings)

    # NOTE(review): the app is handed to uvicorn as an import string. With
    # reload=True it is presumably re-imported in a separate process where
    # configure_server() has not run - confirm the settings survive reload.
    listen_settings = {"host": host, "port": port, "reload": reload}
    uvicorn.run("astrai.inference.server:app", **listen_settings)
|