refactor: 推理协议层重构为策略/建造者模式

- ProtocolHandler 改为具体类,格式化委托给 ResponseBuilder
- 新增 api/protocols/ 目录,含 OpenAIResponseBuilder、AnthropicResponseBuilder
- GenContext、StopInfo 参数对象替代 StreamContext
- 消除 Builder 的实例可变状态(accumulated、_yielded)
- SSE 工具和停止检测收归 ProtocolHandler 统一管理
- prepare() 方法合并原来的 build_prompt、create_response_id
- 参数校验去重:仅 GenerationRequest.__init__ 负责校验
- Prefill 阶段提前短路完全命中的缓存任务
This commit is contained in:
ViperEkura 2026-05-26 00:07:12 +08:00
parent 737585a32a
commit 47c37e4876
8 changed files with 370 additions and 400 deletions

View File

@ -2,24 +2,26 @@
Layers: Layers:
- core/: Core inference loop (cache, executor, scheduler, task) - core/: Core inference loop (cache, executor, scheduler, task)
- api/: HTTP protocol handlers (OpenAI, Anthropic) - api/: HTTP orchestration (ProtocolHandler, server)
- protocols/: Response builders (OpenAI, Anthropic)
- transport/: SSE transport utilities
- engine.py: Facade (InferenceEngine), Value Object (GenerationRequest) - engine.py: Facade (InferenceEngine), Value Object (GenerationRequest)
- sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy) - sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
""" """
from astrai.inference.api import ( from astrai.inference.api import (
AnthropicHandler,
AnthropicMessage, AnthropicMessage,
ChatCompletionRequest, ChatCompletionRequest,
ChatMessage, ChatMessage,
GenContext,
MessagesRequest, MessagesRequest,
OpenAIHandler,
ProtocolHandler, ProtocolHandler,
StopChecker, StopChecker,
StreamContext,
app, app,
run_server, run_server,
) )
from astrai.inference.api.anthropic import AnthropicResponseBuilder
from astrai.inference.api.openai import OpenAIResponseBuilder
from astrai.inference.core import ( from astrai.inference.core import (
STOP, STOP,
Allocator, Allocator,
@ -36,10 +38,7 @@ from astrai.inference.core import (
TaskTable, TaskTable,
page_hash, page_hash,
) )
from astrai.inference.engine import ( from astrai.inference.engine import GenerationRequest, InferenceEngine
GenerationRequest,
InferenceEngine,
)
from astrai.inference.sample import ( from astrai.inference.sample import (
BaseSamplingStrategy, BaseSamplingStrategy,
SamplingPipeline, SamplingPipeline,
@ -50,17 +49,14 @@ from astrai.inference.sample import (
) )
__all__ = [ __all__ = [
# Engine / Requests
"InferenceEngine", "InferenceEngine",
"GenerationRequest", "GenerationRequest",
# Core scheduler
"InferenceScheduler", "InferenceScheduler",
"Executor", "Executor",
"STOP", "STOP",
"Task", "Task",
"TaskManager", "TaskManager",
"TaskStatus", "TaskStatus",
# Core cache
"Allocator", "Allocator",
"KVCache", "KVCache",
"KvcacheView", "KvcacheView",
@ -69,20 +65,17 @@ __all__ = [
"Storage", "Storage",
"TaskTable", "TaskTable",
"page_hash", "page_hash",
# Sampling (Strategy pattern)
"sample", "sample",
"BaseSamplingStrategy", "BaseSamplingStrategy",
"TemperatureStrategy", "TemperatureStrategy",
"TopKStrategy", "TopKStrategy",
"TopPStrategy", "TopPStrategy",
"SamplingPipeline", "SamplingPipeline",
# Protocol
"ProtocolHandler", "ProtocolHandler",
"StopChecker", "StopChecker",
"StreamContext", "GenContext",
"AnthropicHandler", "OpenAIResponseBuilder",
"OpenAIHandler", "AnthropicResponseBuilder",
# Server
"ChatMessage", "ChatMessage",
"ChatCompletionRequest", "ChatCompletionRequest",
"AnthropicMessage", "AnthropicMessage",

View File

@ -1,12 +1,6 @@
"""Inference API: protocol handlers and FastAPI server.""" """Inference API: protocol handler, stop checker, and FastAPI server."""
from astrai.inference.api.protocol import ( from astrai.inference.api.protocol import GenContext, ProtocolHandler, StopChecker
AnthropicHandler,
OpenAIHandler,
ProtocolHandler,
StopChecker,
StreamContext,
)
from astrai.inference.api.server import ( from astrai.inference.api.server import (
AnthropicMessage, AnthropicMessage,
ChatCompletionRequest, ChatCompletionRequest,
@ -17,11 +11,9 @@ from astrai.inference.api.server import (
) )
__all__ = [ __all__ = [
"AnthropicHandler",
"OpenAIHandler",
"ProtocolHandler", "ProtocolHandler",
"StopChecker", "StopChecker",
"StreamContext", "GenContext",
"AnthropicMessage", "AnthropicMessage",
"ChatCompletionRequest", "ChatCompletionRequest",
"ChatMessage", "ChatMessage",

View File

@ -0,0 +1,140 @@
"""Anthropic message completion response builder."""
import uuid
from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel
from astrai.inference.api.protocol import (
GenContext,
ResponseBuilder,
StopInfo,
sse_event,
)
from astrai.inference.engine import InferenceEngine
def _extract_text(content: Union[str, List[Dict[str, Any]]]) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
return block.get("text", "")
return ""
class AnthropicResponseBuilder(ResponseBuilder):
    """Formats generation output as Anthropic-style ``message`` payloads.

    The builder keeps no per-request state: everything it needs is passed in
    through ``GenContext`` and ``StopInfo``.
    """

    @staticmethod
    def _trim_at_stop(text: str, matched: Optional[str]) -> str:
        """Return ``text`` cut just before the last occurrence of ``matched``.

        ``text`` is returned unchanged when ``matched`` is empty/``None`` or
        does not occur in ``text``.
        """
        if not matched:
            return text
        idx = text.rfind(matched)
        # rfind() yields -1 when the sequence is absent; slicing with -1
        # would silently drop the final character.
        return text if idx == -1 else text[:idx]

    def prepare(
        self, request: BaseModel, engine: InferenceEngine
    ) -> Tuple[str, GenContext, List[str]]:
        """Build the prompt, generation context and stop sequences.

        Args:
            request: Request object exposing ``messages`` and ``model`` and,
                optionally, ``system`` and ``stop_sequences``.
            engine: Engine whose tokenizer renders the chat template.

        Returns:
            ``(prompt, ctx, stop_sequences)``. ``ctx.prompt_tokens`` is left
            at 0 here; the caller is expected to fill it in.
        """
        messages: List[Dict[str, str]] = []
        system = getattr(request, "system", None)
        if system:
            messages.append({"role": "system", "content": system})
        for m in request.messages:
            text = _extract_text(m.content)
            # Messages without any text content are skipped entirely.
            if text:
                messages.append({"role": m.role, "content": text})

        prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
        ctx = GenContext(
            resp_id=f"msg_{uuid.uuid4().hex[:24]}",
            created=0,  # not referenced by any payload this builder emits
            model=request.model,
            prompt_tokens=0,
        )
        stop_sequences = getattr(request, "stop_sequences", None) or []
        return prompt, ctx, stop_sequences

    def format_stream_start(self, ctx: GenContext) -> List[str]:
        """Return the ``message_start`` and ``content_block_start`` events."""
        return [
            sse_event(
                {
                    "type": "message_start",
                    "message": {
                        "id": ctx.resp_id,
                        "type": "message",
                        "role": "assistant",
                        "model": ctx.model,
                        "content": [],
                        "usage": {"input_tokens": ctx.prompt_tokens},
                    },
                },
                event="message_start",
            ),
            sse_event(
                {
                    "type": "content_block_start",
                    "index": 0,
                    "content_block": {"type": "text", "text": ""},
                },
                event="content_block_start",
            ),
        ]

    def format_chunk(self, token: str) -> str:
        """Return a ``content_block_delta`` event carrying ``token``."""
        return sse_event(
            {
                "type": "content_block_delta",
                "index": 0,
                "delta": {"type": "text_delta", "text": token},
            },
            event="content_block_delta",
        )

    def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
        """Return the events that close the stream.

        When a stop sequence matched, any text that precedes it but was not
        yet emitted (``stop.body`` minus ``stop.yielded``) is flushed as one
        final delta before the closing events.
        """
        events: List[str] = []
        if stop.matched:
            visible = self._trim_at_stop(stop.body, stop.matched)
            unyielded = visible[len(stop.yielded):]
            if unyielded:
                # Same payload as a regular token delta, so reuse the
                # formatter instead of duplicating the dict.
                events.append(self.format_chunk(unyielded))

        events.append(
            sse_event(
                {"type": "content_block_stop", "index": 0},
                event="content_block_stop",
            )
        )
        events.append(
            sse_event(
                {
                    "type": "message_delta",
                    "delta": {
                        "stop_reason": "stop_sequence" if stop.matched else "end_turn",
                        "stop_sequence": stop.matched,
                    },
                    "usage": {"output_tokens": ctx.completion_tokens},
                },
                event="message_delta",
            )
        )
        events.append(sse_event({"type": "message_stop"}, event="message_stop"))
        return events

    def format_response(
        self, ctx: GenContext, content: str, stop: StopInfo
    ) -> Dict[str, Any]:
        """Return the non-streaming JSON body.

        ``content`` is cut before the matched stop sequence, if any.
        """
        content = self._trim_at_stop(content, stop.matched)
        return {
            "id": ctx.resp_id,
            "type": "message",
            "role": "assistant",
            "model": ctx.model,
            "content": [{"type": "text", "text": content}],
            "stop_reason": "stop_sequence" if stop.matched else "end_turn",
            "stop_sequence": stop.matched,
            "usage": {
                "input_tokens": ctx.prompt_tokens,
                "output_tokens": ctx.completion_tokens,
            },
        }

View File

@ -0,0 +1,111 @@
"""OpenAI chat completion response builder."""
import uuid
from typing import Any, Dict, List, Optional, Tuple
from pydantic import BaseModel
from astrai.inference.api.protocol import (
GenContext,
ResponseBuilder,
StopInfo,
sse_event,
)
from astrai.inference.engine import InferenceEngine
class OpenAIResponseBuilder(ResponseBuilder):
    """Formats generation output as OpenAI-style chat completion payloads.

    ``format_chunk`` receives only the token, so the response id, model and
    creation timestamp chosen in ``prepare`` are cached on the instance for
    it. Every other formatter reads them from the ``GenContext`` it is given.
    Use one builder instance per request.
    """

    def __init__(self) -> None:
        # Defaults keep format_chunk() from raising AttributeError if it is
        # ever called before prepare().
        self._resp_id = ""
        self._model = ""
        self._created = 0

    @staticmethod
    def _trim_at_stop(text: str, matched: Optional[str]) -> str:
        """Return ``text`` cut just before the last occurrence of ``matched``.

        ``text`` is returned unchanged when ``matched`` is empty/``None`` or
        does not occur in ``text``.
        """
        if not matched:
            return text
        idx = text.rfind(matched)
        # rfind() yields -1 when absent; slicing with -1 would drop a char.
        return text if idx == -1 else text[:idx]

    def prepare(
        self, request: BaseModel, engine: InferenceEngine
    ) -> Tuple[str, GenContext, List[str]]:
        """Build the prompt, generation context and stop sequences.

        Args:
            request: Request object exposing ``messages``, ``model`` and
                ``stop`` (``None``, a single string, or a list of strings).
            engine: Engine whose tokenizer renders the chat template.

        Returns:
            ``(prompt, ctx, stop_sequences)``. ``ctx.prompt_tokens`` is left
            at 0 here; the caller is expected to fill it in.
        """
        # Local import so this block does not depend on the module's import
        # header being edited.
        import time

        messages = [{"role": m.role, "content": m.content} for m in request.messages]
        prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)

        self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
        self._model = request.model
        # Stamp the real creation time; a constant 0 made every response
        # report the Unix epoch.
        self._created = int(time.time())

        ctx = GenContext(
            resp_id=self._resp_id,
            created=self._created,
            model=self._model,
            prompt_tokens=0,
        )

        stop = request.stop
        if stop is None:
            stop_sequences: List[str] = []
        elif isinstance(stop, str):
            stop_sequences = [stop]
        else:
            stop_sequences = list(stop)
        return prompt, ctx, stop_sequences

    def format_stream_start(self, ctx: GenContext) -> List[str]:
        """Return the opening chunk that announces the assistant role."""
        return [
            sse_event(
                {
                    "id": ctx.resp_id,
                    "object": "chat.completion.chunk",
                    "created": ctx.created,
                    "model": ctx.model,
                    "choices": [
                        {
                            "index": 0,
                            "delta": {"role": "assistant"},
                            "finish_reason": None,
                        }
                    ],
                }
            )
        ]

    def format_chunk(self, token: str) -> str:
        """Return a chunk event carrying ``token`` as delta content."""
        return sse_event(
            {
                "id": self._resp_id,
                "object": "chat.completion.chunk",
                # Same timestamp as the start/end chunks of this response.
                "created": self._created,
                "model": self._model,
                "choices": [
                    {"index": 0, "delta": {"content": token}, "finish_reason": None}
                ],
            }
        )

    def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
        """Return the events that close the stream.

        When a stop sequence matched, text that precedes it but was not yet
        emitted (``stop.body`` minus ``stop.yielded``) is flushed first, so
        content sharing a token with the stop sequence is not lost. The
        finish chunk and the usage event follow.
        """
        events: List[str] = []
        if stop.matched:
            visible = self._trim_at_stop(stop.body, stop.matched)
            unyielded = visible[len(stop.yielded):]
            if unyielded:
                events.append(self.format_chunk(unyielded))

        events.append(
            sse_event(
                {
                    "id": ctx.resp_id,
                    "object": "chat.completion.chunk",
                    "created": ctx.created,
                    "model": ctx.model,
                    "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
                }
            )
        )
        events.append(
            sse_event(
                {
                    "prompt_tokens": ctx.prompt_tokens,
                    "completion_tokens": ctx.completion_tokens,
                    "total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
                }
            )
        )
        return events

    def format_response(
        self, ctx: GenContext, content: str, stop: StopInfo
    ) -> Dict[str, Any]:
        """Return the non-streaming JSON body.

        ``content`` is cut before the matched stop sequence, if any, so the
        stop sequence itself is never returned to the client.
        """
        content = self._trim_at_stop(content, stop.matched)
        return {
            "id": ctx.resp_id,
            "object": "chat.completion",
            "created": ctx.created,
            "model": ctx.model,
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": content},
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": ctx.prompt_tokens,
                "completion_tokens": ctx.completion_tokens,
                "total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
            },
        }

View File

@ -1,15 +1,14 @@
"""Protocol handlers for OpenAI and Anthropic chat completion APIs. """Orchestration layer: ProtocolHandler, StopChecker, GenContext, StopInfo, ResponseBuilder, SSE utils.
Template Method + Builder patterns eliminate the 45% code duplication between ProtocolHandler orchestrates the async generation loop and delegates
stream/non-stream branches and across protocol adapters. protocol-specific formatting to a ResponseBuilder.
""" """
import json import json
import time import time
import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
@ -17,7 +16,7 @@ from pydantic import BaseModel
from astrai.inference.engine import InferenceEngine from astrai.inference.engine import InferenceEngine
def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str: def sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
lines: List[str] = [] lines: List[str] = []
if event: if event:
lines.append(f"event: {event}") lines.append(f"event: {event}")
@ -26,22 +25,28 @@ def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
return "\n".join(lines) return "\n".join(lines)
def _sse_done() -> str: def sse_done() -> str:
return "data: [DONE]\n\n" return "data: [DONE]\n\n"
@dataclass @dataclass
class StreamContext: class GenContext:
"""Shared state across the streaming generation lifecycle.""" """Per-generation metadata passed to builder format methods."""
resp_id: str resp_id: str
created: int created: int
model: str model: str
prompt_tokens: int prompt_tokens: int
completion_tokens: int = 0 completion_tokens: int = 0
accumulated: str = ""
stop_matched: Optional[str] = None
last_yield_trimmed: str = "" @dataclass
class StopInfo:
"""Stop-check result passed to format_stream_end / format_response."""
matched: Optional[str] = None
body: str = ""
yielded: str = ""
class StopChecker: class StopChecker:
@ -56,95 +61,60 @@ class StopChecker:
return seq return seq
return None return None
def trim(self, text: str, matched: str) -> str:
idx = text.rfind(matched)
return text[:idx] if idx != -1 else text
@property class ResponseBuilder(ABC):
def has_sequences(self) -> bool: """Interface for protocol-specific response formatting.
return len(self._sequences) > 0
A new protocol requires one concrete builder implementing 5 methods.
class ProtocolHandler(ABC):
"""Template-method base for API protocol handlers.
Subclasses implement format hooks; the base class orchestrates the
generate-async loop and SSE/JSON response construction.
Lifecycle::
handle()
build_prompt() # protocol-specific prompt assembly
create_response_id() # unique response identifier
[stream]
format_stream_start()
format_stream_token() × N
on_token() hook for stop-sequence interception
format_stream_end()
[non-stream]
(accumulate tokens)
format_non_stream_response()
""" """
request_model: type[BaseModel] @abstractmethod
def prepare(
self, request: BaseModel, engine: InferenceEngine
) -> Tuple[str, GenContext, List[str]]:
"""Return (prompt, ctx, stop_sequences) for a generation request."""
def __init__(self, request: BaseModel, engine: InferenceEngine): @abstractmethod
def format_stream_start(self, ctx: GenContext) -> List[str]:
"""SSE events that open the stream."""
@abstractmethod
def format_chunk(self, token: str) -> str:
"""SSE event for a single generated token."""
@abstractmethod
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
"""SSE events that close the stream."""
@abstractmethod
def format_response(
self, ctx: GenContext, content: str, stop: StopInfo
) -> Dict[str, Any]:
"""JSON response body for non-streaming mode."""
class ProtocolHandler:
"""Orchestrates the generation loop, delegates formatting to a builder.
Usage::
handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
response = await handler.handle()
"""
def __init__(
self, request: BaseModel, engine: InferenceEngine, builder: ResponseBuilder
):
self.request = request self.request = request
self.engine = engine self.engine = engine
self.builder = builder
@abstractmethod
def build_prompt(self) -> str:
"""Build the full prompt string from the request messages."""
@abstractmethod
def create_response_id(self) -> str:
"""Generate a unique response ID following the protocol convention."""
@abstractmethod
def format_stream_start(self, ctx: StreamContext) -> List[str]:
"""Yield SSE events that open the stream (role marker, metadata)."""
@abstractmethod
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
"""Yield an SSE event for a single generated token."""
@abstractmethod
def format_stream_end(self, ctx: StreamContext) -> List[str]:
"""Yield SSE events that close the stream (finish reason, usage stats)."""
@abstractmethod
def format_non_stream_response(
self, ctx: StreamContext, content: str
) -> Dict[str, Any]:
"""Build the JSON response body for non-streaming mode."""
def get_stop_sequences(self) -> List[str]:
return []
def create_stop_checker(self) -> StopChecker:
return StopChecker(self.get_stop_sequences())
def on_token(
self, ctx: StreamContext, token: str, stop_checker: StopChecker
) -> Optional[str]:
"""Hook after each token is appended to accumulated.
Return a matched stop-sequence string to break the loop,
or None to continue.
"""
return None
async def handle(self) -> Union[StreamingResponse, Dict[str, Any]]: async def handle(self) -> Union[StreamingResponse, Dict[str, Any]]:
ctx = StreamContext( prompt, ctx, stop_sequences = self.builder.prepare(self.request, self.engine)
resp_id=self.create_response_id(), ctx.prompt_tokens = len(self.engine.tokenizer.encode(prompt))
created=int(time.time()),
model=self.request.model,
prompt_tokens=self._count_prompt_tokens(),
)
agen = self.engine.generate_async( agen = self.engine.generate_async(
prompt=self.build_prompt(), prompt=prompt,
max_tokens=self.request.max_tokens, max_tokens=self.request.max_tokens,
temperature=self.request.temperature, temperature=self.request.temperature,
top_p=self.request.top_p, top_p=self.request.top_p,
@ -152,33 +122,37 @@ class ProtocolHandler(ABC):
) )
if self.request.stream: if self.request.stream:
return self._handle_stream(agen, ctx) return self._handle_stream(agen, ctx, stop_sequences)
else: else:
return await self._handle_non_stream(agen, ctx) return await self._handle_non_stream(agen, ctx, stop_sequences)
def _count_prompt_tokens(self) -> int: def _handle_stream(
return len(self.engine.tokenizer.encode(self.build_prompt())) self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str]
) -> StreamingResponse:
def _handle_stream(self, agen, ctx: StreamContext) -> StreamingResponse: checker = StopChecker(stop_sequences)
stop_checker = self.create_stop_checker()
async def event_stream(): async def event_stream():
for event in self.format_stream_start(ctx): for event in self.builder.format_stream_start(ctx):
yield event yield event
body = ""
yielded = ""
matched = None
async for token in agen: async for token in agen:
ctx.completion_tokens += 1 ctx.completion_tokens += 1
ctx.accumulated += token body += token
matched = self.on_token(ctx, token, stop_checker) matched = checker.check(body)
if matched: if matched:
break break
yield self.format_stream_token(ctx, token) yield self.builder.format_chunk(token)
yielded += token
for event in self.format_stream_end(ctx): stop = StopInfo(matched=matched, body=body, yielded=yielded)
for event in self.builder.format_stream_end(ctx, stop):
yield event yield event
yield _sse_done() yield sse_done()
return StreamingResponse( return StreamingResponse(
event_stream(), event_stream(),
@ -186,260 +160,23 @@ class ProtocolHandler(ABC):
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
) )
async def _handle_non_stream(self, agen, ctx: StreamContext) -> Dict[str, Any]: async def _handle_non_stream(
stop_checker = self.create_stop_checker() self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str]
) -> Dict[str, Any]:
checker = StopChecker(stop_sequences)
chunks: List[str] = [] chunks: List[str] = []
body = ""
matched = None
async for token in agen: async for token in agen:
ctx.completion_tokens += 1 ctx.completion_tokens += 1
ctx.accumulated += token
chunks.append(token) chunks.append(token)
body += token
matched = self.on_token(ctx, token, stop_checker) matched = checker.check(body)
if matched: if matched:
break break
content = "".join(chunks) content = "".join(chunks)
return self.format_non_stream_response(ctx, content) stop = StopInfo(matched=matched, body=body)
return self.builder.format_response(ctx, content, stop)
def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str:
"""Extract plain text from an Anthropic content block (string or list)."""
if isinstance(content, str):
return content
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
return block.get("text", "")
return ""
class OpenAIHandler(ProtocolHandler):
"""OpenAI-compatible /v1/chat/completions handler."""
def build_prompt(self) -> str:
messages = [
{"role": m.role, "content": m.content} for m in self.request.messages
]
return self.engine.tokenizer.apply_chat_template(messages, tokenize=False)
def create_response_id(self) -> str:
return f"chatcmpl-{uuid.uuid4().hex[:12]}"
def get_stop_sequences(self) -> List[str]:
stop = self.request.stop
if stop is None:
return []
return [stop] if isinstance(stop, str) else stop
def on_token(
self, ctx: StreamContext, token: str, stop_checker: StopChecker
) -> Optional[str]:
return stop_checker.check(ctx.accumulated)
def format_stream_start(self, ctx: StreamContext) -> List[str]:
return [
_sse_event(
{
"id": ctx.resp_id,
"object": "chat.completion.chunk",
"created": ctx.created,
"model": ctx.model,
"choices": [
{
"index": 0,
"delta": {"role": "assistant"},
"finish_reason": None,
}
],
}
)
]
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
return _sse_event(
{
"id": ctx.resp_id,
"object": "chat.completion.chunk",
"created": ctx.created,
"model": ctx.model,
"choices": [
{"index": 0, "delta": {"content": token}, "finish_reason": None}
],
}
)
def format_stream_end(self, ctx: StreamContext) -> List[str]:
return [
_sse_event(
{
"id": ctx.resp_id,
"object": "chat.completion.chunk",
"created": ctx.created,
"model": ctx.model,
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
}
),
_sse_event(
{
"prompt_tokens": ctx.prompt_tokens,
"completion_tokens": ctx.completion_tokens,
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
}
),
]
def format_non_stream_response(
self, ctx: StreamContext, content: str
) -> Dict[str, Any]:
return {
"id": ctx.resp_id,
"object": "chat.completion",
"created": ctx.created,
"model": ctx.model,
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": content},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": ctx.prompt_tokens,
"completion_tokens": ctx.completion_tokens,
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
},
}
class AnthropicHandler(ProtocolHandler):
"""Anthropic-compatible /v1/messages handler."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._yielded = ""
def build_prompt(self) -> str:
messages: List[Dict[str, str]] = []
system = getattr(self.request, "system", None)
if system:
messages.append({"role": "system", "content": system})
for m in self.request.messages:
content = _extract_text_content(m.content)
if content:
messages.append({"role": m.role, "content": content})
return self.engine.tokenizer.apply_chat_template(messages, tokenize=False)
def create_response_id(self) -> str:
return f"msg_{uuid.uuid4().hex[:24]}"
def get_stop_sequences(self) -> List[str]:
return getattr(self.request, "stop_sequences", None) or []
def on_token(
self, ctx: StreamContext, token: str, stop_checker: StopChecker
) -> Optional[str]:
matched = stop_checker.check(ctx.accumulated)
if not matched:
return None
ctx.stop_matched = matched
trimmed = ctx.accumulated[: ctx.accumulated.rfind(matched)]
unyielded = trimmed[len(self._yielded) :]
if unyielded:
ctx.last_yield_trimmed = unyielded
return matched
def format_stream_start(self, ctx: StreamContext) -> List[str]:
return [
_sse_event(
{
"type": "message_start",
"message": {
"id": ctx.resp_id,
"type": "message",
"role": "assistant",
"model": ctx.model,
"content": [],
"usage": {"input_tokens": ctx.prompt_tokens},
},
},
event="message_start",
),
_sse_event(
{
"type": "content_block_start",
"index": 0,
"content_block": {"type": "text", "text": ""},
},
event="content_block_start",
),
]
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
self._yielded += token
return _sse_event(
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": token},
},
event="content_block_delta",
)
def format_stream_end(self, ctx: StreamContext) -> List[str]:
matched = ctx.stop_matched
events: List[str] = []
last_yielded = ctx.last_yield_trimmed
if last_yielded:
events.append(
_sse_event(
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": last_yielded},
},
event="content_block_delta",
)
)
events.append(
_sse_event(
{"type": "content_block_stop", "index": 0},
event="content_block_stop",
)
)
events.append(
_sse_event(
{
"type": "message_delta",
"delta": {
"stop_reason": "stop_sequence" if matched else "end_turn",
"stop_sequence": matched,
},
"usage": {"output_tokens": ctx.completion_tokens},
},
event="message_delta",
)
)
events.append(_sse_event({"type": "message_stop"}, event="message_stop"))
return events
def format_non_stream_response(
self, ctx: StreamContext, content: str
) -> Dict[str, Any]:
matched = ctx.stop_matched
if matched:
content = content[: content.rfind(matched)]
return {
"id": ctx.resp_id,
"type": "message",
"role": "assistant",
"model": ctx.model,
"content": [{"type": "text", "text": content}],
"stop_reason": "stop_sequence" if matched else "end_turn",
"stop_sequence": matched,
"usage": {
"input_tokens": ctx.prompt_tokens,
"output_tokens": ctx.completion_tokens,
},
}

View File

@ -15,7 +15,9 @@ import uvicorn
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from astrai.inference.api.protocol import AnthropicHandler, OpenAIHandler from astrai.inference.api.anthropic import AnthropicResponseBuilder
from astrai.inference.api.openai import OpenAIResponseBuilder
from astrai.inference.api.protocol import ProtocolHandler
from astrai.inference.engine import InferenceEngine from astrai.inference.engine import InferenceEngine
from astrai.model import AutoModel from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer
@ -133,14 +135,14 @@ async def get_stats():
@app.post("/v1/chat/completions") @app.post("/v1/chat/completions")
async def chat_completion(request: ChatCompletionRequest): async def chat_completion(request: ChatCompletionRequest):
engine = _get_engine() engine = _get_engine()
handler = OpenAIHandler(request, engine) handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
return await handler.handle() return await handler.handle()
@app.post("/v1/messages") @app.post("/v1/messages")
async def create_message(request: MessagesRequest): async def create_message(request: MessagesRequest):
engine = _get_engine() engine = _get_engine()
handler = AnthropicHandler(request, engine) handler = ProtocolHandler(request, engine, AnthropicResponseBuilder())
return await handler.handle() return await handler.handle()

View File

@ -108,7 +108,10 @@ class InferenceScheduler:
continue continue
to_prefill = [ to_prefill = [
t for t in self._task_mgr.get_active_tasks() if t.output_tokens == 0 t
for t in self._task_mgr.get_active_tasks()
if t.output_tokens == 0
and self._page_cache.task_cached(t.task_id) < len(t.prompt_ids)
] ]
if to_prefill: if to_prefill:
for t in to_prefill: for t in to_prefill:

View File

@ -13,17 +13,6 @@ from astrai.inference.core.task import STOP
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer
def _validate_sampling_params(
top_k: int, top_p: float, temperature: float, max_tokens: Optional[int] = None
):
if not (isinstance(top_k, int) and top_k >= 0):
raise ValueError("top_k must be a non-negative integer")
if not (0.0 <= top_p <= 1.0):
raise ValueError("top_p must be a float between 0.0 and 1.0")
if not (isinstance(temperature, (int, float)) and temperature >= 0):
raise ValueError("temperature must be a non-negative number")
class GenerateResult: class GenerateResult:
"""Thread-safe token accumulator for streaming and non-streaming modes.""" """Thread-safe token accumulator for streaming and non-streaming modes."""
@ -86,7 +75,12 @@ class GenerationRequest:
max_tokens: Optional[int] = None, max_tokens: Optional[int] = None,
stream: bool = False, stream: bool = False,
): ):
_validate_sampling_params(top_k, top_p, temperature, max_tokens) if not (isinstance(top_k, int) and top_k >= 0):
raise ValueError("top_k must be a non-negative integer")
if not (0.0 <= top_p <= 1.0):
raise ValueError("top_p must be a float between 0.0 and 1.0")
if not (isinstance(temperature, (int, float)) and temperature >= 0):
raise ValueError("temperature must be a non-negative number")
self.messages = messages self.messages = messages
self.top_k = top_k self.top_k = top_k
@ -137,7 +131,6 @@ class InferenceEngine:
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = 50, top_k: int = 50,
) -> Union[Generator, str, List[str]]: ) -> Union[Generator, str, List[str]]:
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
is_batch = isinstance(prompt, list) is_batch = isinstance(prompt, list)
prompts = prompt if is_batch else [prompt] prompts = prompt if is_batch else [prompt]
@ -158,7 +151,6 @@ class InferenceEngine:
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = 50, top_k: int = 50,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
sync_gen = self._generate_streaming( sync_gen = self._generate_streaming(
[prompt], False, max_tokens, temperature, top_p, top_k [prompt], False, max_tokens, temperature, top_p, top_k
) )