Compare commits
No commits in common. "65ab69543b4da3afc440a1efd6005bb4cbcfda22" and "737585a32abdbe6064f63df259b207f3e2c5fae5" have entirely different histories.
65ab69543b
...
737585a32a
|
|
@ -43,7 +43,6 @@ class ResumableDistributedSampler(Sampler[int]):
|
||||||
offset = 0 if drop_last else self.num_replicas - 1
|
offset = 0 if drop_last else self.num_replicas - 1
|
||||||
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
|
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
|
||||||
self.total_size = self.num_samples_per_replica * self.num_replicas
|
self.total_size = self.num_samples_per_replica * self.num_replicas
|
||||||
self.iter = self.iter % self.num_samples_per_replica
|
|
||||||
|
|
||||||
self._indices = None
|
self._indices = None
|
||||||
|
|
||||||
|
|
@ -75,10 +74,5 @@ class ResumableDistributedSampler(Sampler[int]):
|
||||||
self.epoch += 1
|
self.epoch += 1
|
||||||
self._indices = None
|
self._indices = None
|
||||||
|
|
||||||
@property
|
|
||||||
def _remaining(self):
|
|
||||||
remaining = self.num_samples_per_replica - self.iter
|
|
||||||
return max(remaining, 0)
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self._remaining
|
return self.num_samples_per_replica
|
||||||
|
|
|
||||||
|
|
@ -2,26 +2,24 @@
|
||||||
|
|
||||||
Layers:
|
Layers:
|
||||||
- core/: Core inference loop (cache, executor, scheduler, task)
|
- core/: Core inference loop (cache, executor, scheduler, task)
|
||||||
- api/: HTTP orchestration (ProtocolHandler, server)
|
- api/: HTTP protocol handlers (OpenAI, Anthropic)
|
||||||
- protocols/: Response builders (OpenAI, Anthropic)
|
|
||||||
- transport/: SSE transport utilities
|
|
||||||
- engine.py: Facade (InferenceEngine), Value Object (GenerationRequest)
|
- engine.py: Facade (InferenceEngine), Value Object (GenerationRequest)
|
||||||
- sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
|
- sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from astrai.inference.api import (
|
from astrai.inference.api import (
|
||||||
|
AnthropicHandler,
|
||||||
AnthropicMessage,
|
AnthropicMessage,
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
GenContext,
|
|
||||||
MessagesRequest,
|
MessagesRequest,
|
||||||
|
OpenAIHandler,
|
||||||
ProtocolHandler,
|
ProtocolHandler,
|
||||||
StopChecker,
|
StopChecker,
|
||||||
|
StreamContext,
|
||||||
app,
|
app,
|
||||||
run_server,
|
run_server,
|
||||||
)
|
)
|
||||||
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
|
||||||
from astrai.inference.api.openai import OpenAIResponseBuilder
|
|
||||||
from astrai.inference.core import (
|
from astrai.inference.core import (
|
||||||
STOP,
|
STOP,
|
||||||
Allocator,
|
Allocator,
|
||||||
|
|
@ -38,7 +36,10 @@ from astrai.inference.core import (
|
||||||
TaskTable,
|
TaskTable,
|
||||||
page_hash,
|
page_hash,
|
||||||
)
|
)
|
||||||
from astrai.inference.engine import GenerationRequest, InferenceEngine
|
from astrai.inference.engine import (
|
||||||
|
GenerationRequest,
|
||||||
|
InferenceEngine,
|
||||||
|
)
|
||||||
from astrai.inference.sample import (
|
from astrai.inference.sample import (
|
||||||
BaseSamplingStrategy,
|
BaseSamplingStrategy,
|
||||||
SamplingPipeline,
|
SamplingPipeline,
|
||||||
|
|
@ -49,14 +50,17 @@ from astrai.inference.sample import (
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
# Engine / Requests
|
||||||
"InferenceEngine",
|
"InferenceEngine",
|
||||||
"GenerationRequest",
|
"GenerationRequest",
|
||||||
|
# Core scheduler
|
||||||
"InferenceScheduler",
|
"InferenceScheduler",
|
||||||
"Executor",
|
"Executor",
|
||||||
"STOP",
|
"STOP",
|
||||||
"Task",
|
"Task",
|
||||||
"TaskManager",
|
"TaskManager",
|
||||||
"TaskStatus",
|
"TaskStatus",
|
||||||
|
# Core cache
|
||||||
"Allocator",
|
"Allocator",
|
||||||
"KVCache",
|
"KVCache",
|
||||||
"KvcacheView",
|
"KvcacheView",
|
||||||
|
|
@ -65,17 +69,20 @@ __all__ = [
|
||||||
"Storage",
|
"Storage",
|
||||||
"TaskTable",
|
"TaskTable",
|
||||||
"page_hash",
|
"page_hash",
|
||||||
|
# Sampling (Strategy pattern)
|
||||||
"sample",
|
"sample",
|
||||||
"BaseSamplingStrategy",
|
"BaseSamplingStrategy",
|
||||||
"TemperatureStrategy",
|
"TemperatureStrategy",
|
||||||
"TopKStrategy",
|
"TopKStrategy",
|
||||||
"TopPStrategy",
|
"TopPStrategy",
|
||||||
"SamplingPipeline",
|
"SamplingPipeline",
|
||||||
|
# Protocol
|
||||||
"ProtocolHandler",
|
"ProtocolHandler",
|
||||||
"StopChecker",
|
"StopChecker",
|
||||||
"GenContext",
|
"StreamContext",
|
||||||
"OpenAIResponseBuilder",
|
"AnthropicHandler",
|
||||||
"AnthropicResponseBuilder",
|
"OpenAIHandler",
|
||||||
|
# Server
|
||||||
"ChatMessage",
|
"ChatMessage",
|
||||||
"ChatCompletionRequest",
|
"ChatCompletionRequest",
|
||||||
"AnthropicMessage",
|
"AnthropicMessage",
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,12 @@
|
||||||
"""Inference API: protocol handler, stop checker, and FastAPI server."""
|
"""Inference API: protocol handlers and FastAPI server."""
|
||||||
|
|
||||||
from astrai.inference.api.protocol import GenContext, ProtocolHandler, StopChecker
|
from astrai.inference.api.protocol import (
|
||||||
|
AnthropicHandler,
|
||||||
|
OpenAIHandler,
|
||||||
|
ProtocolHandler,
|
||||||
|
StopChecker,
|
||||||
|
StreamContext,
|
||||||
|
)
|
||||||
from astrai.inference.api.server import (
|
from astrai.inference.api.server import (
|
||||||
AnthropicMessage,
|
AnthropicMessage,
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
|
|
@ -11,9 +17,11 @@ from astrai.inference.api.server import (
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"AnthropicHandler",
|
||||||
|
"OpenAIHandler",
|
||||||
"ProtocolHandler",
|
"ProtocolHandler",
|
||||||
"StopChecker",
|
"StopChecker",
|
||||||
"GenContext",
|
"StreamContext",
|
||||||
"AnthropicMessage",
|
"AnthropicMessage",
|
||||||
"ChatCompletionRequest",
|
"ChatCompletionRequest",
|
||||||
"ChatMessage",
|
"ChatMessage",
|
||||||
|
|
|
||||||
|
|
@ -1,140 +0,0 @@
|
||||||
"""Anthropic message completion response builder."""
|
|
||||||
|
|
||||||
import uuid
|
|
||||||
from typing import Any, Dict, List, Tuple, Union
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from astrai.inference.api.protocol import (
|
|
||||||
GenContext,
|
|
||||||
ResponseBuilder,
|
|
||||||
StopInfo,
|
|
||||||
sse_event,
|
|
||||||
)
|
|
||||||
from astrai.inference.engine import InferenceEngine
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_text(content: Union[str, List[Dict[str, Any]]]) -> str:
|
|
||||||
if isinstance(content, str):
|
|
||||||
return content
|
|
||||||
if isinstance(content, list):
|
|
||||||
for block in content:
|
|
||||||
if isinstance(block, dict) and block.get("type") == "text":
|
|
||||||
return block.get("text", "")
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
class AnthropicResponseBuilder(ResponseBuilder):
|
|
||||||
def prepare(
|
|
||||||
self, request: BaseModel, engine: InferenceEngine
|
|
||||||
) -> Tuple[str, GenContext, List[str]]:
|
|
||||||
messages: List[Dict[str, str]] = []
|
|
||||||
system = getattr(request, "system", None)
|
|
||||||
if system:
|
|
||||||
messages.append({"role": "system", "content": system})
|
|
||||||
for m in request.messages:
|
|
||||||
text = _extract_text(m.content)
|
|
||||||
if text:
|
|
||||||
messages.append({"role": m.role, "content": text})
|
|
||||||
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
|
||||||
ctx = GenContext(
|
|
||||||
resp_id=f"msg_{uuid.uuid4().hex[:24]}",
|
|
||||||
created=0,
|
|
||||||
model=request.model,
|
|
||||||
prompt_tokens=0,
|
|
||||||
)
|
|
||||||
stop_sequences = getattr(request, "stop_sequences", None) or []
|
|
||||||
return prompt, ctx, stop_sequences
|
|
||||||
|
|
||||||
def format_stream_start(self, ctx: GenContext) -> List[str]:
|
|
||||||
return [
|
|
||||||
sse_event(
|
|
||||||
{
|
|
||||||
"type": "message_start",
|
|
||||||
"message": {
|
|
||||||
"id": ctx.resp_id,
|
|
||||||
"type": "message",
|
|
||||||
"role": "assistant",
|
|
||||||
"model": ctx.model,
|
|
||||||
"content": [],
|
|
||||||
"usage": {"input_tokens": ctx.prompt_tokens},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
event="message_start",
|
|
||||||
),
|
|
||||||
sse_event(
|
|
||||||
{
|
|
||||||
"type": "content_block_start",
|
|
||||||
"index": 0,
|
|
||||||
"content_block": {"type": "text", "text": ""},
|
|
||||||
},
|
|
||||||
event="content_block_start",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
def format_chunk(self, token: str) -> str:
|
|
||||||
return sse_event(
|
|
||||||
{
|
|
||||||
"type": "content_block_delta",
|
|
||||||
"index": 0,
|
|
||||||
"delta": {"type": "text_delta", "text": token},
|
|
||||||
},
|
|
||||||
event="content_block_delta",
|
|
||||||
)
|
|
||||||
|
|
||||||
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
|
||||||
events: List[str] = []
|
|
||||||
if stop.matched:
|
|
||||||
trimmed = stop.body[: stop.body.rfind(stop.matched)]
|
|
||||||
unyielded = trimmed[len(stop.yielded) :]
|
|
||||||
if unyielded:
|
|
||||||
events.append(
|
|
||||||
sse_event(
|
|
||||||
{
|
|
||||||
"type": "content_block_delta",
|
|
||||||
"index": 0,
|
|
||||||
"delta": {"type": "text_delta", "text": unyielded},
|
|
||||||
},
|
|
||||||
event="content_block_delta",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
events.append(
|
|
||||||
sse_event(
|
|
||||||
{"type": "content_block_stop", "index": 0},
|
|
||||||
event="content_block_stop",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
events.append(
|
|
||||||
sse_event(
|
|
||||||
{
|
|
||||||
"type": "message_delta",
|
|
||||||
"delta": {
|
|
||||||
"stop_reason": "stop_sequence" if stop.matched else "end_turn",
|
|
||||||
"stop_sequence": stop.matched,
|
|
||||||
},
|
|
||||||
"usage": {"output_tokens": ctx.completion_tokens},
|
|
||||||
},
|
|
||||||
event="message_delta",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
events.append(sse_event({"type": "message_stop"}, event="message_stop"))
|
|
||||||
return events
|
|
||||||
|
|
||||||
def format_response(
|
|
||||||
self, ctx: GenContext, content: str, stop: StopInfo
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
if stop.matched:
|
|
||||||
content = content[: content.rfind(stop.matched)]
|
|
||||||
return {
|
|
||||||
"id": ctx.resp_id,
|
|
||||||
"type": "message",
|
|
||||||
"role": "assistant",
|
|
||||||
"model": ctx.model,
|
|
||||||
"content": [{"type": "text", "text": content}],
|
|
||||||
"stop_reason": "stop_sequence" if stop.matched else "end_turn",
|
|
||||||
"stop_sequence": stop.matched,
|
|
||||||
"usage": {
|
|
||||||
"input_tokens": ctx.prompt_tokens,
|
|
||||||
"output_tokens": ctx.completion_tokens,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
@ -1,111 +0,0 @@
|
||||||
"""OpenAI chat completion response builder."""
|
|
||||||
|
|
||||||
import uuid
|
|
||||||
from typing import Any, Dict, List, Tuple
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from astrai.inference.api.protocol import (
|
|
||||||
GenContext,
|
|
||||||
ResponseBuilder,
|
|
||||||
StopInfo,
|
|
||||||
sse_event,
|
|
||||||
)
|
|
||||||
from astrai.inference.engine import InferenceEngine
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIResponseBuilder(ResponseBuilder):
|
|
||||||
def prepare(
|
|
||||||
self, request: BaseModel, engine: InferenceEngine
|
|
||||||
) -> Tuple[str, GenContext, List[str]]:
|
|
||||||
messages = [{"role": m.role, "content": m.content} for m in request.messages]
|
|
||||||
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
|
||||||
|
|
||||||
self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
|
||||||
self._model = request.model
|
|
||||||
|
|
||||||
ctx = GenContext(
|
|
||||||
resp_id=self._resp_id,
|
|
||||||
created=0,
|
|
||||||
model=self._model,
|
|
||||||
prompt_tokens=0,
|
|
||||||
)
|
|
||||||
stop = request.stop
|
|
||||||
stop_sequences = (
|
|
||||||
[] if stop is None else [stop] if isinstance(stop, str) else stop
|
|
||||||
)
|
|
||||||
return prompt, ctx, stop_sequences
|
|
||||||
|
|
||||||
def format_stream_start(self, ctx: GenContext) -> List[str]:
|
|
||||||
return [
|
|
||||||
sse_event(
|
|
||||||
{
|
|
||||||
"id": self._resp_id,
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": ctx.created,
|
|
||||||
"model": self._model,
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"delta": {"role": "assistant"},
|
|
||||||
"finish_reason": None,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
def format_chunk(self, token: str) -> str:
|
|
||||||
return sse_event(
|
|
||||||
{
|
|
||||||
"id": self._resp_id,
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": 0,
|
|
||||||
"model": self._model,
|
|
||||||
"choices": [
|
|
||||||
{"index": 0, "delta": {"content": token}, "finish_reason": None}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
|
||||||
return [
|
|
||||||
sse_event(
|
|
||||||
{
|
|
||||||
"id": self._resp_id,
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": ctx.created,
|
|
||||||
"model": self._model,
|
|
||||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
|
||||||
}
|
|
||||||
),
|
|
||||||
sse_event(
|
|
||||||
{
|
|
||||||
"prompt_tokens": ctx.prompt_tokens,
|
|
||||||
"completion_tokens": ctx.completion_tokens,
|
|
||||||
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
def format_response(
|
|
||||||
self, ctx: GenContext, content: str, stop: StopInfo
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"id": self._resp_id,
|
|
||||||
"object": "chat.completion",
|
|
||||||
"created": ctx.created,
|
|
||||||
"model": self._model,
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"message": {"role": "assistant", "content": content},
|
|
||||||
"finish_reason": "stop",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": ctx.prompt_tokens,
|
|
||||||
"completion_tokens": ctx.completion_tokens,
|
|
||||||
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
@ -1,13 +1,15 @@
|
||||||
"""Orchestration layer: ProtocolHandler, StopChecker, GenContext, StopInfo, ResponseBuilder, SSE utils.
|
"""Protocol handlers for OpenAI and Anthropic chat completion APIs.
|
||||||
|
|
||||||
ProtocolHandler orchestrates the async generation loop and delegates
|
Template Method + Builder patterns eliminate the 45% code duplication between
|
||||||
protocol-specific formatting to a ResponseBuilder.
|
stream/non-stream branches and across protocol adapters.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
@ -15,7 +17,7 @@ from pydantic import BaseModel
|
||||||
from astrai.inference.engine import InferenceEngine
|
from astrai.inference.engine import InferenceEngine
|
||||||
|
|
||||||
|
|
||||||
def sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
||||||
lines: List[str] = []
|
lines: List[str] = []
|
||||||
if event:
|
if event:
|
||||||
lines.append(f"event: {event}")
|
lines.append(f"event: {event}")
|
||||||
|
|
@ -24,28 +26,22 @@ def sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
def sse_done() -> str:
|
def _sse_done() -> str:
|
||||||
return "data: [DONE]\n\n"
|
return "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GenContext:
|
class StreamContext:
|
||||||
"""Per-generation metadata passed to builder format methods."""
|
"""Shared state across the streaming generation lifecycle."""
|
||||||
|
|
||||||
resp_id: str
|
resp_id: str
|
||||||
created: int
|
created: int
|
||||||
model: str
|
model: str
|
||||||
prompt_tokens: int
|
prompt_tokens: int
|
||||||
completion_tokens: int = 0
|
completion_tokens: int = 0
|
||||||
|
accumulated: str = ""
|
||||||
|
stop_matched: Optional[str] = None
|
||||||
@dataclass
|
last_yield_trimmed: str = ""
|
||||||
class StopInfo:
|
|
||||||
"""Stop-check result passed to format_stream_end / format_response."""
|
|
||||||
|
|
||||||
matched: Optional[str] = None
|
|
||||||
body: str = ""
|
|
||||||
yielded: str = ""
|
|
||||||
|
|
||||||
|
|
||||||
class StopChecker:
|
class StopChecker:
|
||||||
|
|
@ -60,60 +56,95 @@ class StopChecker:
|
||||||
return seq
|
return seq
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def trim(self, text: str, matched: str) -> str:
|
||||||
|
idx = text.rfind(matched)
|
||||||
|
return text[:idx] if idx != -1 else text
|
||||||
|
|
||||||
class ResponseBuilder(ABC):
|
@property
|
||||||
"""Interface for protocol-specific response formatting.
|
def has_sequences(self) -> bool:
|
||||||
|
return len(self._sequences) > 0
|
||||||
|
|
||||||
A new protocol requires one concrete builder implementing 6 methods.
|
|
||||||
|
class ProtocolHandler(ABC):
|
||||||
|
"""Template-method base for API protocol handlers.
|
||||||
|
|
||||||
|
Subclasses implement format hooks; the base class orchestrates the
|
||||||
|
generate-async loop and SSE/JSON response construction.
|
||||||
|
|
||||||
|
Lifecycle::
|
||||||
|
|
||||||
|
handle()
|
||||||
|
├─ build_prompt() # protocol-specific prompt assembly
|
||||||
|
├─ create_response_id() # unique response identifier
|
||||||
|
├─ [stream]
|
||||||
|
│ ├─ format_stream_start()
|
||||||
|
│ ├─ format_stream_token() × N
|
||||||
|
│ │ └─ on_token() hook for stop-sequence interception
|
||||||
|
│ └─ format_stream_end()
|
||||||
|
└─ [non-stream]
|
||||||
|
├─ (accumulate tokens)
|
||||||
|
└─ format_non_stream_response()
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
request_model: type[BaseModel]
|
||||||
def prepare(
|
|
||||||
self, request: BaseModel, engine: InferenceEngine
|
|
||||||
) -> Tuple[str, GenContext, List[str]]:
|
|
||||||
"""Return (prompt, ctx, stop_sequences) for a generation request."""
|
|
||||||
|
|
||||||
@abstractmethod
|
def __init__(self, request: BaseModel, engine: InferenceEngine):
|
||||||
def format_stream_start(self, ctx: GenContext) -> List[str]:
|
|
||||||
"""SSE events that open the stream."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def format_chunk(self, token: str) -> str:
|
|
||||||
"""SSE event for a single generated token."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
|
||||||
"""SSE events that close the stream."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def format_response(
|
|
||||||
self, ctx: GenContext, content: str, stop: StopInfo
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""JSON response body for non-streaming mode."""
|
|
||||||
|
|
||||||
|
|
||||||
class ProtocolHandler:
|
|
||||||
"""Orchestrates the generation loop, delegates formatting to a builder.
|
|
||||||
|
|
||||||
Usage::
|
|
||||||
|
|
||||||
handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
|
|
||||||
response = await handler.handle()
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, request: BaseModel, engine: InferenceEngine, builder: ResponseBuilder
|
|
||||||
):
|
|
||||||
self.request = request
|
self.request = request
|
||||||
self.engine = engine
|
self.engine = engine
|
||||||
self.builder = builder
|
|
||||||
|
@abstractmethod
|
||||||
|
def build_prompt(self) -> str:
|
||||||
|
"""Build the full prompt string from the request messages."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create_response_id(self) -> str:
|
||||||
|
"""Generate a unique response ID following the protocol convention."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
||||||
|
"""Yield SSE events that open the stream (role marker, metadata)."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
|
||||||
|
"""Yield an SSE event for a single generated token."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format_stream_end(self, ctx: StreamContext) -> List[str]:
|
||||||
|
"""Yield SSE events that close the stream (finish reason, usage stats)."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format_non_stream_response(
|
||||||
|
self, ctx: StreamContext, content: str
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Build the JSON response body for non-streaming mode."""
|
||||||
|
|
||||||
|
def get_stop_sequences(self) -> List[str]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def create_stop_checker(self) -> StopChecker:
|
||||||
|
return StopChecker(self.get_stop_sequences())
|
||||||
|
|
||||||
|
def on_token(
|
||||||
|
self, ctx: StreamContext, token: str, stop_checker: StopChecker
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""Hook after each token is appended to accumulated.
|
||||||
|
|
||||||
|
Return a matched stop-sequence string to break the loop,
|
||||||
|
or None to continue.
|
||||||
|
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
async def handle(self) -> Union[StreamingResponse, Dict[str, Any]]:
|
async def handle(self) -> Union[StreamingResponse, Dict[str, Any]]:
|
||||||
prompt, ctx, stop_sequences = self.builder.prepare(self.request, self.engine)
|
ctx = StreamContext(
|
||||||
ctx.prompt_tokens = len(self.engine.tokenizer.encode(prompt))
|
resp_id=self.create_response_id(),
|
||||||
|
created=int(time.time()),
|
||||||
|
model=self.request.model,
|
||||||
|
prompt_tokens=self._count_prompt_tokens(),
|
||||||
|
)
|
||||||
|
|
||||||
agen = self.engine.generate_async(
|
agen = self.engine.generate_async(
|
||||||
prompt=prompt,
|
prompt=self.build_prompt(),
|
||||||
max_tokens=self.request.max_tokens,
|
max_tokens=self.request.max_tokens,
|
||||||
temperature=self.request.temperature,
|
temperature=self.request.temperature,
|
||||||
top_p=self.request.top_p,
|
top_p=self.request.top_p,
|
||||||
|
|
@ -121,37 +152,33 @@ class ProtocolHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.request.stream:
|
if self.request.stream:
|
||||||
return self._handle_stream(agen, ctx, stop_sequences)
|
return self._handle_stream(agen, ctx)
|
||||||
else:
|
else:
|
||||||
return await self._handle_non_stream(agen, ctx, stop_sequences)
|
return await self._handle_non_stream(agen, ctx)
|
||||||
|
|
||||||
def _handle_stream(
|
def _count_prompt_tokens(self) -> int:
|
||||||
self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str]
|
return len(self.engine.tokenizer.encode(self.build_prompt()))
|
||||||
) -> StreamingResponse:
|
|
||||||
checker = StopChecker(stop_sequences)
|
def _handle_stream(self, agen, ctx: StreamContext) -> StreamingResponse:
|
||||||
|
stop_checker = self.create_stop_checker()
|
||||||
|
|
||||||
async def event_stream():
|
async def event_stream():
|
||||||
for event in self.builder.format_stream_start(ctx):
|
for event in self.format_stream_start(ctx):
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
body = ""
|
|
||||||
yielded = ""
|
|
||||||
matched = None
|
|
||||||
async for token in agen:
|
async for token in agen:
|
||||||
ctx.completion_tokens += 1
|
ctx.completion_tokens += 1
|
||||||
body += token
|
ctx.accumulated += token
|
||||||
|
|
||||||
matched = checker.check(body)
|
matched = self.on_token(ctx, token, stop_checker)
|
||||||
if matched:
|
if matched:
|
||||||
break
|
break
|
||||||
|
|
||||||
yield self.builder.format_chunk(token)
|
yield self.format_stream_token(ctx, token)
|
||||||
yielded += token
|
|
||||||
|
|
||||||
stop = StopInfo(matched=matched, body=body, yielded=yielded)
|
for event in self.format_stream_end(ctx):
|
||||||
for event in self.builder.format_stream_end(ctx, stop):
|
|
||||||
yield event
|
yield event
|
||||||
yield sse_done()
|
yield _sse_done()
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
event_stream(),
|
event_stream(),
|
||||||
|
|
@ -159,23 +186,260 @@ class ProtocolHandler:
|
||||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _handle_non_stream(
|
async def _handle_non_stream(self, agen, ctx: StreamContext) -> Dict[str, Any]:
|
||||||
self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str]
|
stop_checker = self.create_stop_checker()
|
||||||
) -> Dict[str, Any]:
|
|
||||||
checker = StopChecker(stop_sequences)
|
|
||||||
chunks: List[str] = []
|
chunks: List[str] = []
|
||||||
body = ""
|
|
||||||
matched = None
|
|
||||||
|
|
||||||
async for token in agen:
|
async for token in agen:
|
||||||
ctx.completion_tokens += 1
|
ctx.completion_tokens += 1
|
||||||
|
ctx.accumulated += token
|
||||||
chunks.append(token)
|
chunks.append(token)
|
||||||
body += token
|
|
||||||
|
|
||||||
matched = checker.check(body)
|
matched = self.on_token(ctx, token, stop_checker)
|
||||||
if matched:
|
if matched:
|
||||||
break
|
break
|
||||||
|
|
||||||
content = "".join(chunks)
|
content = "".join(chunks)
|
||||||
stop = StopInfo(matched=matched, body=body)
|
return self.format_non_stream_response(ctx, content)
|
||||||
return self.builder.format_response(ctx, content, stop)
|
|
||||||
|
|
||||||
|
def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str:
|
||||||
|
"""Extract plain text from an Anthropic content block (string or list)."""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
for block in content:
|
||||||
|
if isinstance(block, dict) and block.get("type") == "text":
|
||||||
|
return block.get("text", "")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIHandler(ProtocolHandler):
|
||||||
|
"""OpenAI-compatible /v1/chat/completions handler."""
|
||||||
|
|
||||||
|
def build_prompt(self) -> str:
|
||||||
|
messages = [
|
||||||
|
{"role": m.role, "content": m.content} for m in self.request.messages
|
||||||
|
]
|
||||||
|
return self.engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||||
|
|
||||||
|
def create_response_id(self) -> str:
|
||||||
|
return f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
||||||
|
|
||||||
|
def get_stop_sequences(self) -> List[str]:
|
||||||
|
stop = self.request.stop
|
||||||
|
if stop is None:
|
||||||
|
return []
|
||||||
|
return [stop] if isinstance(stop, str) else stop
|
||||||
|
|
||||||
|
def on_token(
|
||||||
|
self, ctx: StreamContext, token: str, stop_checker: StopChecker
|
||||||
|
) -> Optional[str]:
|
||||||
|
return stop_checker.check(ctx.accumulated)
|
||||||
|
|
||||||
|
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
||||||
|
return [
|
||||||
|
_sse_event(
|
||||||
|
{
|
||||||
|
"id": ctx.resp_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": ctx.created,
|
||||||
|
"model": ctx.model,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"role": "assistant"},
|
||||||
|
"finish_reason": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
|
||||||
|
return _sse_event(
|
||||||
|
{
|
||||||
|
"id": ctx.resp_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": ctx.created,
|
||||||
|
"model": ctx.model,
|
||||||
|
"choices": [
|
||||||
|
{"index": 0, "delta": {"content": token}, "finish_reason": None}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_stream_end(self, ctx: StreamContext) -> List[str]:
|
||||||
|
return [
|
||||||
|
_sse_event(
|
||||||
|
{
|
||||||
|
"id": ctx.resp_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": ctx.created,
|
||||||
|
"model": ctx.model,
|
||||||
|
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||||
|
}
|
||||||
|
),
|
||||||
|
_sse_event(
|
||||||
|
{
|
||||||
|
"prompt_tokens": ctx.prompt_tokens,
|
||||||
|
"completion_tokens": ctx.completion_tokens,
|
||||||
|
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
def format_non_stream_response(
|
||||||
|
self, ctx: StreamContext, content: str
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"id": ctx.resp_id,
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": ctx.created,
|
||||||
|
"model": ctx.model,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": {"role": "assistant", "content": content},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": ctx.prompt_tokens,
|
||||||
|
"completion_tokens": ctx.completion_tokens,
|
||||||
|
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class AnthropicHandler(ProtocolHandler):
|
||||||
|
"""Anthropic-compatible /v1/messages handler."""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._yielded = ""
|
||||||
|
|
||||||
|
def build_prompt(self) -> str:
|
||||||
|
messages: List[Dict[str, str]] = []
|
||||||
|
system = getattr(self.request, "system", None)
|
||||||
|
if system:
|
||||||
|
messages.append({"role": "system", "content": system})
|
||||||
|
for m in self.request.messages:
|
||||||
|
content = _extract_text_content(m.content)
|
||||||
|
if content:
|
||||||
|
messages.append({"role": m.role, "content": content})
|
||||||
|
return self.engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||||
|
|
||||||
|
def create_response_id(self) -> str:
|
||||||
|
return f"msg_{uuid.uuid4().hex[:24]}"
|
||||||
|
|
||||||
|
def get_stop_sequences(self) -> List[str]:
|
||||||
|
return getattr(self.request, "stop_sequences", None) or []
|
||||||
|
|
||||||
|
def on_token(
|
||||||
|
self, ctx: StreamContext, token: str, stop_checker: StopChecker
|
||||||
|
) -> Optional[str]:
|
||||||
|
matched = stop_checker.check(ctx.accumulated)
|
||||||
|
if not matched:
|
||||||
|
return None
|
||||||
|
|
||||||
|
ctx.stop_matched = matched
|
||||||
|
trimmed = ctx.accumulated[: ctx.accumulated.rfind(matched)]
|
||||||
|
unyielded = trimmed[len(self._yielded) :]
|
||||||
|
if unyielded:
|
||||||
|
ctx.last_yield_trimmed = unyielded
|
||||||
|
return matched
|
||||||
|
|
||||||
|
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
||||||
|
return [
|
||||||
|
_sse_event(
|
||||||
|
{
|
||||||
|
"type": "message_start",
|
||||||
|
"message": {
|
||||||
|
"id": ctx.resp_id,
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"model": ctx.model,
|
||||||
|
"content": [],
|
||||||
|
"usage": {"input_tokens": ctx.prompt_tokens},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
event="message_start",
|
||||||
|
),
|
||||||
|
_sse_event(
|
||||||
|
{
|
||||||
|
"type": "content_block_start",
|
||||||
|
"index": 0,
|
||||||
|
"content_block": {"type": "text", "text": ""},
|
||||||
|
},
|
||||||
|
event="content_block_start",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
|
||||||
|
self._yielded += token
|
||||||
|
return _sse_event(
|
||||||
|
{
|
||||||
|
"type": "content_block_delta",
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"type": "text_delta", "text": token},
|
||||||
|
},
|
||||||
|
event="content_block_delta",
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_stream_end(self, ctx: StreamContext) -> List[str]:
|
||||||
|
matched = ctx.stop_matched
|
||||||
|
events: List[str] = []
|
||||||
|
last_yielded = ctx.last_yield_trimmed
|
||||||
|
if last_yielded:
|
||||||
|
events.append(
|
||||||
|
_sse_event(
|
||||||
|
{
|
||||||
|
"type": "content_block_delta",
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"type": "text_delta", "text": last_yielded},
|
||||||
|
},
|
||||||
|
event="content_block_delta",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
events.append(
|
||||||
|
_sse_event(
|
||||||
|
{"type": "content_block_stop", "index": 0},
|
||||||
|
event="content_block_stop",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
events.append(
|
||||||
|
_sse_event(
|
||||||
|
{
|
||||||
|
"type": "message_delta",
|
||||||
|
"delta": {
|
||||||
|
"stop_reason": "stop_sequence" if matched else "end_turn",
|
||||||
|
"stop_sequence": matched,
|
||||||
|
},
|
||||||
|
"usage": {"output_tokens": ctx.completion_tokens},
|
||||||
|
},
|
||||||
|
event="message_delta",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
events.append(_sse_event({"type": "message_stop"}, event="message_stop"))
|
||||||
|
return events
|
||||||
|
|
||||||
|
def format_non_stream_response(
|
||||||
|
self, ctx: StreamContext, content: str
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
matched = ctx.stop_matched
|
||||||
|
if matched:
|
||||||
|
content = content[: content.rfind(matched)]
|
||||||
|
return {
|
||||||
|
"id": ctx.resp_id,
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"model": ctx.model,
|
||||||
|
"content": [{"type": "text", "text": content}],
|
||||||
|
"stop_reason": "stop_sequence" if matched else "end_turn",
|
||||||
|
"stop_sequence": matched,
|
||||||
|
"usage": {
|
||||||
|
"input_tokens": ctx.prompt_tokens,
|
||||||
|
"output_tokens": ctx.completion_tokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -15,9 +15,7 @@ import uvicorn
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
from astrai.inference.api.protocol import AnthropicHandler, OpenAIHandler
|
||||||
from astrai.inference.api.openai import OpenAIResponseBuilder
|
|
||||||
from astrai.inference.api.protocol import ProtocolHandler
|
|
||||||
from astrai.inference.engine import InferenceEngine
|
from astrai.inference.engine import InferenceEngine
|
||||||
from astrai.model import AutoModel
|
from astrai.model import AutoModel
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
@ -135,14 +133,14 @@ async def get_stats():
|
||||||
@app.post("/v1/chat/completions")
|
@app.post("/v1/chat/completions")
|
||||||
async def chat_completion(request: ChatCompletionRequest):
|
async def chat_completion(request: ChatCompletionRequest):
|
||||||
engine = _get_engine()
|
engine = _get_engine()
|
||||||
handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
|
handler = OpenAIHandler(request, engine)
|
||||||
return await handler.handle()
|
return await handler.handle()
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/messages")
|
@app.post("/v1/messages")
|
||||||
async def create_message(request: MessagesRequest):
|
async def create_message(request: MessagesRequest):
|
||||||
engine = _get_engine()
|
engine = _get_engine()
|
||||||
handler = ProtocolHandler(request, engine, AnthropicResponseBuilder())
|
handler = AnthropicHandler(request, engine)
|
||||||
return await handler.handle()
|
return await handler.handle()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -108,10 +108,7 @@ class InferenceScheduler:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
to_prefill = [
|
to_prefill = [
|
||||||
t
|
t for t in self._task_mgr.get_active_tasks() if t.output_tokens == 0
|
||||||
for t in self._task_mgr.get_active_tasks()
|
|
||||||
if t.output_tokens == 0
|
|
||||||
and self._page_cache.task_cached(t.task_id) < len(t.prompt_ids)
|
|
||||||
]
|
]
|
||||||
if to_prefill:
|
if to_prefill:
|
||||||
for t in to_prefill:
|
for t in to_prefill:
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,17 @@ from astrai.inference.core.task import STOP
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_sampling_params(
|
||||||
|
top_k: int, top_p: float, temperature: float, max_tokens: Optional[int] = None
|
||||||
|
):
|
||||||
|
if not (isinstance(top_k, int) and top_k >= 0):
|
||||||
|
raise ValueError("top_k must be a non-negative integer")
|
||||||
|
if not (0.0 <= top_p <= 1.0):
|
||||||
|
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
||||||
|
if not (isinstance(temperature, (int, float)) and temperature >= 0):
|
||||||
|
raise ValueError("temperature must be a non-negative number")
|
||||||
|
|
||||||
|
|
||||||
class GenerateResult:
|
class GenerateResult:
|
||||||
"""Thread-safe token accumulator for streaming and non-streaming modes."""
|
"""Thread-safe token accumulator for streaming and non-streaming modes."""
|
||||||
|
|
||||||
|
|
@ -75,12 +86,7 @@ class GenerationRequest:
|
||||||
max_tokens: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
):
|
):
|
||||||
if not (isinstance(top_k, int) and top_k >= 0):
|
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
|
||||||
raise ValueError("top_k must be a non-negative integer")
|
|
||||||
if not (0.0 <= top_p <= 1.0):
|
|
||||||
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
|
||||||
if not (isinstance(temperature, (int, float)) and temperature >= 0):
|
|
||||||
raise ValueError("temperature must be a non-negative number")
|
|
||||||
|
|
||||||
self.messages = messages
|
self.messages = messages
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
|
|
@ -131,6 +137,7 @@ class InferenceEngine:
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = 50,
|
top_k: int = 50,
|
||||||
) -> Union[Generator, str, List[str]]:
|
) -> Union[Generator, str, List[str]]:
|
||||||
|
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
|
||||||
is_batch = isinstance(prompt, list)
|
is_batch = isinstance(prompt, list)
|
||||||
prompts = prompt if is_batch else [prompt]
|
prompts = prompt if is_batch else [prompt]
|
||||||
|
|
||||||
|
|
@ -151,6 +158,7 @@ class InferenceEngine:
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = 50,
|
top_k: int = 50,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
|
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
|
||||||
sync_gen = self._generate_streaming(
|
sync_gen = self._generate_streaming(
|
||||||
[prompt], False, max_tokens, temperature, top_p, top_k
|
[prompt], False, max_tokens, temperature, top_p, top_k
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -2,15 +2,16 @@
|
||||||
AutoModel base class for model loading and saving.
|
AutoModel base class for model loading and saving.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Self, Union
|
from typing import Self, Union
|
||||||
|
|
||||||
|
import safetensors.torch as st
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from astrai.config.model_config import BaseModelConfig, ConfigFactory
|
from astrai.config.model_config import BaseModelConfig, ConfigFactory
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
from astrai.serialization import load_model_config, load_model_weights, save_model
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
|
|
@ -59,22 +60,25 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
||||||
|
|
||||||
model_path = Path(path)
|
model_path = Path(path)
|
||||||
|
|
||||||
|
# Load config
|
||||||
config_path = model_path / "config.json"
|
config_path = model_path / "config.json"
|
||||||
if not config_path.exists():
|
if config_path.exists():
|
||||||
raise FileNotFoundError(f"Config file not found: {config_path}")
|
with open(config_path, "r") as f:
|
||||||
|
raw = json.load(f)
|
||||||
raw = load_model_config(str(model_path))
|
|
||||||
config = ConfigFactory.load(raw)
|
config = ConfigFactory.load(raw)
|
||||||
model_type = config.model_type or "autoregressive_lm"
|
model_type = config.model_type or "autoregressive_lm"
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError(f"Config file not found: {config_path}")
|
||||||
|
|
||||||
actual_cls = AutoModel.get_component_class(model_type)
|
actual_cls = AutoModel.get_component_class(model_type)
|
||||||
|
|
||||||
with _disable_random_init(enable=disable_random_init):
|
with _disable_random_init(enable=disable_random_init):
|
||||||
model = actual_cls(config)
|
model = actual_cls(config)
|
||||||
|
|
||||||
|
# Load weights
|
||||||
weights_path = model_path / "model.safetensors"
|
weights_path = model_path / "model.safetensors"
|
||||||
if weights_path.exists():
|
if weights_path.exists():
|
||||||
state_dict = load_model_weights(str(model_path))
|
state_dict = st.load_file(str(weights_path))
|
||||||
model.load_state_dict(state_dict, strict=strict)
|
model.load_state_dict(state_dict, strict=strict)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
@ -83,11 +87,14 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
||||||
self,
|
self,
|
||||||
save_directory: Union[str, Path],
|
save_directory: Union[str, Path],
|
||||||
) -> None:
|
) -> None:
|
||||||
save_model(
|
save_path = Path(save_directory)
|
||||||
config=self.config.to_dict(),
|
save_path.mkdir(parents=True, exist_ok=True)
|
||||||
state_dict=self.state_dict(),
|
|
||||||
save_directory=str(save_directory),
|
# Save config
|
||||||
)
|
self.config.to_file(str(save_path / "config.json"))
|
||||||
|
|
||||||
|
# Save weights
|
||||||
|
st.save_file(self.state_dict(), str(save_path / "model.safetensors"))
|
||||||
|
|
||||||
def to(self, *args, **kwargs) -> Self:
|
def to(self, *args, **kwargs) -> Self:
|
||||||
"""Move model to device/dtype."""
|
"""Move model to device/dtype."""
|
||||||
|
|
|
||||||
|
|
@ -1,19 +1,15 @@
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Set
|
from typing import Optional, Set
|
||||||
|
|
||||||
|
import safetensors.torch as st
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from astrai.model.components.linear import Linear
|
from astrai.model.components.linear import Linear
|
||||||
from astrai.serialization import (
|
|
||||||
load_json,
|
|
||||||
load_safetensors,
|
|
||||||
save_json,
|
|
||||||
save_safetensors,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -132,14 +128,16 @@ def save_lora(model: nn.Module, save_dir: str, config: LoRAConfig):
|
||||||
|
|
||||||
path = Path(save_dir)
|
path = Path(save_dir)
|
||||||
path.mkdir(parents=True, exist_ok=True)
|
path.mkdir(parents=True, exist_ok=True)
|
||||||
save_safetensors(lora_sd, path / "adapter_model.safetensors")
|
st.save_file(lora_sd, str(path / "adapter_model.safetensors"))
|
||||||
save_json(asdict(config), path / "adapter_config.json")
|
with open(path / "adapter_config.json", "w") as f:
|
||||||
|
json.dump(asdict(config), f, indent=2)
|
||||||
logger.info("LoRA adapter saved to %s (%d keys)", save_dir, len(lora_sd))
|
logger.info("LoRA adapter saved to %s (%d keys)", save_dir, len(lora_sd))
|
||||||
|
|
||||||
|
|
||||||
def load_lora(model: nn.Module, load_dir: str) -> LoRAConfig:
|
def load_lora(model: nn.Module, load_dir: str) -> LoRAConfig:
|
||||||
path = Path(load_dir)
|
path = Path(load_dir)
|
||||||
raw = load_json(path / "adapter_config.json")
|
with open(path / "adapter_config.json") as f:
|
||||||
|
raw = json.load(f)
|
||||||
config = LoRAConfig(
|
config = LoRAConfig(
|
||||||
r=raw["r"], alpha=raw["alpha"], target_modules=tuple(raw["target_modules"])
|
r=raw["r"], alpha=raw["alpha"], target_modules=tuple(raw["target_modules"])
|
||||||
)
|
)
|
||||||
|
|
@ -159,7 +157,7 @@ def load_lora(model: nn.Module, load_dir: str) -> LoRAConfig:
|
||||||
target_modules=set(config.target_modules),
|
target_modules=set(config.target_modules),
|
||||||
)
|
)
|
||||||
|
|
||||||
weights = load_safetensors(path / "adapter_model.safetensors")
|
weights = st.load_file(str(path / "adapter_model.safetensors"))
|
||||||
try:
|
try:
|
||||||
missing, unexpected = model.load_state_dict(weights, strict=False)
|
missing, unexpected = model.load_state_dict(weights, strict=False)
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,7 @@
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import safetensors.torch as st
|
import safetensors.torch as st
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -10,101 +9,75 @@ import torch.distributed as dist
|
||||||
|
|
||||||
from astrai.parallel.setup import get_rank
|
from astrai.parallel.setup import get_rank
|
||||||
|
|
||||||
_META_FILE = "meta.json"
|
|
||||||
_WEIGHTS_FILE = "model.safetensors"
|
|
||||||
_MODEL_CONFIG_FILE = "config.json"
|
|
||||||
|
|
||||||
|
|
||||||
def save_safetensors(state_dict: dict, path: str | Path) -> None:
|
|
||||||
st.save_file(state_dict, str(path))
|
|
||||||
|
|
||||||
|
|
||||||
def load_safetensors(path: str | Path) -> dict:
|
|
||||||
return st.load_file(str(path))
|
|
||||||
|
|
||||||
|
|
||||||
def save_json(data: dict, path: str | Path) -> None:
|
|
||||||
with open(str(path), "w") as f:
|
|
||||||
json.dump(data, f, indent=2)
|
|
||||||
|
|
||||||
|
|
||||||
def load_json(path: str | Path) -> dict:
|
|
||||||
with open(str(path), "r") as f:
|
|
||||||
return json.load(f)
|
|
||||||
|
|
||||||
|
|
||||||
def save_torch(obj: Any, path: str | Path) -> None:
|
|
||||||
torch.save(obj, str(path))
|
|
||||||
|
|
||||||
|
|
||||||
def load_torch(path: str | Path) -> Any:
|
|
||||||
return torch.load(str(path), map_location="cpu", weights_only=False)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Checkpoint:
|
class Checkpoint:
|
||||||
state_dict: Dict[str, Any] = field(default_factory=dict)
|
def __init__(
|
||||||
epoch: int = 0
|
self,
|
||||||
iteration: int = 0
|
state_dict: Dict[str, Any],
|
||||||
extra: Dict[str, Any] = field(default_factory=dict)
|
epoch: int = 0,
|
||||||
meta: Dict[str, Any] = field(default_factory=dict)
|
iteration: int = 0,
|
||||||
|
extra: Optional[Dict[str, Any]] = None,
|
||||||
|
meta: Optional[Dict[str, Any]] = None,
|
||||||
|
):
|
||||||
|
self.state_dict = state_dict
|
||||||
|
self.epoch = epoch
|
||||||
|
self.iteration = iteration
|
||||||
|
self.extra = extra or {}
|
||||||
|
self.meta = meta or {}
|
||||||
|
|
||||||
|
def save(
|
||||||
|
self,
|
||||||
|
save_dir: str,
|
||||||
|
) -> None:
|
||||||
|
|
||||||
def save(self, save_dir: str) -> None:
|
|
||||||
save_path = Path(save_dir)
|
save_path = Path(save_dir)
|
||||||
save_path.mkdir(parents=True, exist_ok=True)
|
save_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
if get_rank() != 0:
|
rank = get_rank()
|
||||||
return
|
if rank == 0:
|
||||||
|
|
||||||
meta = {
|
meta = {
|
||||||
"epoch": self.epoch,
|
"epoch": self.epoch,
|
||||||
"iteration": self.iteration,
|
"iteration": self.iteration,
|
||||||
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
|
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
|
||||||
**self.meta,
|
|
||||||
}
|
}
|
||||||
save_json(meta, save_path / _META_FILE)
|
meta.update(self.meta)
|
||||||
save_safetensors(self.state_dict, save_path / _WEIGHTS_FILE)
|
with open(save_path / "meta.json", "w") as f:
|
||||||
|
json.dump(meta, f, indent=2)
|
||||||
|
|
||||||
|
st.save_file(self.state_dict, save_path / "state_dict.safetensors")
|
||||||
|
if self.extra:
|
||||||
for key, value in self.extra.items():
|
for key, value in self.extra.items():
|
||||||
save_torch(value, save_path / f"{key}.pt")
|
torch.save(value, save_path / f"{key}.pt")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, save_dir: str) -> "Checkpoint":
|
def load(
|
||||||
|
cls,
|
||||||
|
save_dir: str,
|
||||||
|
) -> "Checkpoint":
|
||||||
|
|
||||||
|
rank = get_rank()
|
||||||
save_path = Path(save_dir)
|
save_path = Path(save_dir)
|
||||||
|
|
||||||
meta = {}
|
meta = {}
|
||||||
if get_rank() == 0:
|
if rank == 0:
|
||||||
meta = load_json(save_path / _META_FILE)
|
with open(Path(save_dir) / "meta.json", "r") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
if dist.is_initialized():
|
if dist.is_initialized():
|
||||||
meta_list = [meta]
|
meta_list = [meta]
|
||||||
dist.broadcast_object_list(meta_list, src=0)
|
dist.broadcast_object_list(meta_list, src=0)
|
||||||
meta = meta_list[0]
|
meta = meta_list[0]
|
||||||
|
|
||||||
state_dict = load_safetensors(save_path / _WEIGHTS_FILE)
|
state_dict = st.load_file(save_path / "state_dict.safetensors")
|
||||||
|
|
||||||
extra = {}
|
extra = {}
|
||||||
for f in save_path.iterdir():
|
for f in save_path.iterdir():
|
||||||
if f.suffix == ".pt":
|
if f.suffix == ".pt" and f.stem not in ("meta",):
|
||||||
extra[f.stem] = load_torch(f)
|
extra[f.stem] = torch.load(f, map_location="cpu", weights_only=False)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
state_dict=state_dict,
|
state_dict=state_dict,
|
||||||
epoch=meta.get("epoch", 0),
|
epoch=meta["epoch"],
|
||||||
iteration=meta.get("iteration", 0),
|
iteration=meta["iteration"],
|
||||||
extra=extra,
|
extra=extra or None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def save_model(config: dict, state_dict: dict, save_directory: str) -> None:
|
|
||||||
save_path = Path(save_directory)
|
|
||||||
save_path.mkdir(parents=True, exist_ok=True)
|
|
||||||
save_json(config, save_path / _MODEL_CONFIG_FILE)
|
|
||||||
save_safetensors(state_dict, save_path / _WEIGHTS_FILE)
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_config(save_directory: str) -> dict:
|
|
||||||
return load_json(Path(save_directory) / _MODEL_CONFIG_FILE)
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_weights(save_directory: str) -> dict:
|
|
||||||
return load_safetensors(Path(save_directory) / _WEIGHTS_FILE)
|
|
||||||
|
|
|
||||||
|
|
@ -210,7 +210,7 @@ class ProgressBarCallback(TrainCallback):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, num_epoch: int, log_interval: int = 100, file: Optional[IO[str]] = None
|
self, num_epoch: int, log_interval: int = 100, file: IO[str] = sys.stdout
|
||||||
):
|
):
|
||||||
self.num_epoch = num_epoch
|
self.num_epoch = num_epoch
|
||||||
self.log_interval = log_interval
|
self.log_interval = log_interval
|
||||||
|
|
@ -223,7 +223,7 @@ class ProgressBarCallback(TrainCallback):
|
||||||
context.dataloader,
|
context.dataloader,
|
||||||
desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
|
desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
|
||||||
dynamic_ncols=True,
|
dynamic_ncols=True,
|
||||||
file=self.file or sys.stdout,
|
file=self.file,
|
||||||
)
|
)
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
|
|
|
||||||
|
|
@ -71,7 +71,6 @@ class TrainContextBuilder:
|
||||||
if self._checkpoint is not None:
|
if self._checkpoint is not None:
|
||||||
context.epoch = max(self._checkpoint.epoch, cfg.start_epoch)
|
context.epoch = max(self._checkpoint.epoch, cfg.start_epoch)
|
||||||
context.iteration = max(self._checkpoint.iteration, cfg.start_batch)
|
context.iteration = max(self._checkpoint.iteration, cfg.start_batch)
|
||||||
if self._checkpoint.state_dict:
|
|
||||||
context.model.load_state_dict(self._checkpoint.state_dict)
|
context.model.load_state_dict(self._checkpoint.state_dict)
|
||||||
context.checkpoint = self._checkpoint
|
context.checkpoint = self._checkpoint
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -2,13 +2,13 @@ import argparse
|
||||||
import os
|
import os
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
|
import safetensors.torch as st
|
||||||
import torch
|
import torch
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
|
|
||||||
from astrai.config import AutoRegressiveLMConfig, TrainConfig
|
from astrai.config import AutoRegressiveLMConfig, TrainConfig
|
||||||
from astrai.dataset import DatasetFactory
|
from astrai.dataset import DatasetFactory
|
||||||
from astrai.model import AutoRegressiveLM
|
from astrai.model import AutoRegressiveLM
|
||||||
from astrai.serialization import Checkpoint
|
|
||||||
from astrai.trainer import SchedulerFactory, Trainer
|
from astrai.trainer import SchedulerFactory, Trainer
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -236,14 +236,16 @@ def train(
|
||||||
if window_size is None:
|
if window_size is None:
|
||||||
window_size = config.max_len
|
window_size = config.max_len
|
||||||
|
|
||||||
# Create model and load full checkpoint (state_dict + optimizer + scheduler + meta)
|
# Create bare AutoRegressiveLM (for training, no tokenizer needed)
|
||||||
checkpoint = Checkpoint.load(param_path)
|
model = AutoRegressiveLM(config)
|
||||||
model = AutoRegressiveLM(config).to(dtype=torch.bfloat16)
|
|
||||||
model.load_state_dict(checkpoint.state_dict, strict=False)
|
|
||||||
|
|
||||||
# Strip state_dict to avoid pickling ~7GB through mp.spawn pipe
|
# Load weights if available
|
||||||
# (model weights already loaded into model above)
|
weights_path = os.path.join(param_path, "model.safetensors")
|
||||||
checkpoint.state_dict = {}
|
if os.path.exists(weights_path):
|
||||||
|
state_dict = st.load_file(weights_path)
|
||||||
|
model.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
|
model = model.to(dtype=torch.bfloat16)
|
||||||
|
|
||||||
strategy_kwargs = {
|
strategy_kwargs = {
|
||||||
"beta": dpo_beta,
|
"beta": dpo_beta,
|
||||||
|
|
@ -255,6 +257,8 @@ def train(
|
||||||
}
|
}
|
||||||
|
|
||||||
executor_kwargs = {
|
executor_kwargs = {
|
||||||
|
"static_graph": True,
|
||||||
|
"find_unused_parameters": False,
|
||||||
"gradient_as_bucket_view": True,
|
"gradient_as_bucket_view": True,
|
||||||
"broadcast_buffers": False,
|
"broadcast_buffers": False,
|
||||||
}
|
}
|
||||||
|
|
@ -315,7 +319,7 @@ def train(
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = Trainer(train_config)
|
trainer = Trainer(train_config)
|
||||||
trainer.train(checkpoint=checkpoint)
|
trainer.train()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import os
|
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -37,6 +36,7 @@ def test_single_process():
|
||||||
|
|
||||||
|
|
||||||
def test_checkpoint_with_extra():
|
def test_checkpoint_with_extra():
|
||||||
|
"""Verify extra keys are saved as individual .pt files and loaded back."""
|
||||||
model = torch.nn.Linear(10, 5)
|
model = torch.nn.Linear(10, 5)
|
||||||
optimizer = AdamW(model.parameters(), lr=1e-3)
|
optimizer = AdamW(model.parameters(), lr=1e-3)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
@ -52,6 +52,8 @@ def test_checkpoint_with_extra():
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
checkpoint.save(tmpdir)
|
checkpoint.save(tmpdir)
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
assert os.path.exists(os.path.join(tmpdir, "optimizer.pt"))
|
assert os.path.exists(os.path.join(tmpdir, "optimizer.pt"))
|
||||||
assert os.path.exists(os.path.join(tmpdir, "scheduler.pt"))
|
assert os.path.exists(os.path.join(tmpdir, "scheduler.pt"))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,286 +0,0 @@
|
||||||
"""Unit tests for protocol builders, StopChecker, GenContext, StopInfo."""
|
|
||||||
|
|
||||||
import json
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
|
||||||
from astrai.inference.api.openai import OpenAIResponseBuilder
|
|
||||||
from astrai.inference.api.protocol import GenContext, StopChecker, StopInfo
|
|
||||||
from astrai.inference.engine import GenerationRequest
|
|
||||||
|
|
||||||
|
|
||||||
def _make_ctx(**kwargs):
|
|
||||||
defaults = {
|
|
||||||
"resp_id": "test-123",
|
|
||||||
"created": 1000,
|
|
||||||
"model": "test-model",
|
|
||||||
"prompt_tokens": 10,
|
|
||||||
"completion_tokens": 5,
|
|
||||||
}
|
|
||||||
defaults.update(kwargs)
|
|
||||||
return GenContext(**defaults)
|
|
||||||
|
|
||||||
|
|
||||||
def _sse_payloads(events):
|
|
||||||
payloads = []
|
|
||||||
for chunk in events:
|
|
||||||
for line in chunk.strip().split("\n"):
|
|
||||||
if line.startswith("data: "):
|
|
||||||
try:
|
|
||||||
payloads.append(json.loads(line[6:]))
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
pass
|
|
||||||
return payloads
|
|
||||||
|
|
||||||
|
|
||||||
class TestStopChecker:
|
|
||||||
def test_check_finds_match(self):
|
|
||||||
sc = StopChecker(["stop", "end"])
|
|
||||||
assert sc.check("hello stop world") == "stop"
|
|
||||||
|
|
||||||
def test_check_returns_none_when_no_match(self):
|
|
||||||
sc = StopChecker(["stop"])
|
|
||||||
assert sc.check("hello world") is None
|
|
||||||
|
|
||||||
def test_check_empty_sequences(self):
|
|
||||||
sc = StopChecker([])
|
|
||||||
assert sc.check("hello") is None
|
|
||||||
|
|
||||||
|
|
||||||
class TestGenContext:
|
|
||||||
def test_defaults(self):
|
|
||||||
ctx = GenContext(resp_id="a", created=1, model="m", prompt_tokens=10)
|
|
||||||
assert ctx.completion_tokens == 0
|
|
||||||
|
|
||||||
def test_fields_mutable(self):
|
|
||||||
ctx = GenContext(resp_id="a", created=1, model="m", prompt_tokens=10)
|
|
||||||
ctx.completion_tokens = 42
|
|
||||||
assert ctx.completion_tokens == 42
|
|
||||||
|
|
||||||
|
|
||||||
class TestStopInfo:
|
|
||||||
def test_defaults(self):
|
|
||||||
s = StopInfo()
|
|
||||||
assert s.matched is None
|
|
||||||
assert s.body == ""
|
|
||||||
assert s.yielded == ""
|
|
||||||
|
|
||||||
def test_with_values(self):
|
|
||||||
s = StopInfo(matched="stop", body="hello stop", yielded="hello ")
|
|
||||||
assert s.matched == "stop"
|
|
||||||
assert s.body == "hello stop"
|
|
||||||
assert s.yielded == "hello "
|
|
||||||
|
|
||||||
|
|
||||||
class TestOpenAIResponseBuilder:
|
|
||||||
@pytest.fixture
|
|
||||||
def builder(self):
|
|
||||||
builder = OpenAIResponseBuilder()
|
|
||||||
req = MagicMock()
|
|
||||||
req.messages = [MagicMock(role="user", content="Hello")]
|
|
||||||
req.stop = None
|
|
||||||
req.model = "astrai"
|
|
||||||
engine = MagicMock()
|
|
||||||
engine.tokenizer.apply_chat_template.return_value = "Hello"
|
|
||||||
builder.prepare(req, engine)
|
|
||||||
return builder
|
|
||||||
|
|
||||||
def test_prepare_returns_prompt_ctx_stops(self, builder):
|
|
||||||
req = MagicMock()
|
|
||||||
req.messages = [MagicMock(role="user", content="Hi")]
|
|
||||||
req.stop = ["END"]
|
|
||||||
req.model = "gpt"
|
|
||||||
engine = MagicMock()
|
|
||||||
engine.tokenizer.apply_chat_template.return_value = "Hi"
|
|
||||||
prompt, ctx, stops = builder.prepare(req, engine)
|
|
||||||
assert prompt == "Hi"
|
|
||||||
assert ctx.model == "gpt"
|
|
||||||
assert ctx.prompt_tokens == 0
|
|
||||||
assert stops == ["END"]
|
|
||||||
|
|
||||||
def test_prepare_no_stop_returns_empty_list(self, builder):
|
|
||||||
req = MagicMock()
|
|
||||||
req.messages = []
|
|
||||||
req.stop = None
|
|
||||||
req.model = "x"
|
|
||||||
engine = MagicMock()
|
|
||||||
engine.tokenizer.apply_chat_template.return_value = ""
|
|
||||||
_, _, stops = builder.prepare(req, engine)
|
|
||||||
assert stops == []
|
|
||||||
|
|
||||||
def test_format_stream_start(self, builder):
|
|
||||||
ctx = _make_ctx()
|
|
||||||
events = builder.format_stream_start(ctx)
|
|
||||||
payloads = _sse_payloads(events)
|
|
||||||
assert len(payloads) == 1
|
|
||||||
p = payloads[0]
|
|
||||||
assert p["object"] == "chat.completion.chunk"
|
|
||||||
assert p["choices"][0]["delta"]["role"] == "assistant"
|
|
||||||
assert p["choices"][0]["finish_reason"] is None
|
|
||||||
|
|
||||||
def test_format_chunk(self, builder):
|
|
||||||
event = builder.format_chunk("hello")
|
|
||||||
payload = json.loads(event.split("data: ", 1)[1])
|
|
||||||
assert payload["choices"][0]["delta"]["content"] == "hello"
|
|
||||||
assert payload["choices"][0]["finish_reason"] is None
|
|
||||||
|
|
||||||
def test_format_stream_end(self, builder):
|
|
||||||
ctx = _make_ctx(completion_tokens=5)
|
|
||||||
stop = StopInfo(matched="stop")
|
|
||||||
events = builder.format_stream_end(ctx, stop)
|
|
||||||
payloads = _sse_payloads(events)
|
|
||||||
finish = payloads[0]
|
|
||||||
assert finish["choices"][0]["finish_reason"] == "stop"
|
|
||||||
usage = payloads[1]
|
|
||||||
assert usage["completion_tokens"] == 5
|
|
||||||
assert usage["total_tokens"] == 15
|
|
||||||
|
|
||||||
def test_format_response(self, builder):
|
|
||||||
ctx = _make_ctx()
|
|
||||||
stop = StopInfo()
|
|
||||||
resp = builder.format_response(ctx, "hello", stop)
|
|
||||||
assert resp["object"] == "chat.completion"
|
|
||||||
assert resp["choices"][0]["message"]["content"] == "hello"
|
|
||||||
assert resp["usage"]["prompt_tokens"] == 10
|
|
||||||
|
|
||||||
|
|
||||||
class TestAnthropicResponseBuilder:
|
|
||||||
@pytest.fixture
|
|
||||||
def builder(self):
|
|
||||||
builder = AnthropicResponseBuilder()
|
|
||||||
req = MagicMock()
|
|
||||||
req.messages = [MagicMock(role="user", content="Hello")]
|
|
||||||
req.model = "claude"
|
|
||||||
engine = MagicMock()
|
|
||||||
engine.tokenizer.apply_chat_template.return_value = "Hello"
|
|
||||||
req.system = None
|
|
||||||
builder.prepare(req, engine)
|
|
||||||
return builder
|
|
||||||
|
|
||||||
def test_prepare_messages(self, builder):
|
|
||||||
req = MagicMock()
|
|
||||||
req.messages = [MagicMock(role="user", content="Hi")]
|
|
||||||
req.model = "claude"
|
|
||||||
req.system = None
|
|
||||||
req.stop_sequences = None
|
|
||||||
engine = MagicMock()
|
|
||||||
engine.tokenizer.apply_chat_template.return_value = "Hi"
|
|
||||||
prompt, ctx, stops = builder.prepare(req, engine)
|
|
||||||
assert prompt == "Hi"
|
|
||||||
assert stops == []
|
|
||||||
|
|
||||||
def test_prepare_with_stop_sequences(self, builder):
|
|
||||||
req = MagicMock()
|
|
||||||
req.messages = []
|
|
||||||
req.model = "x"
|
|
||||||
req.stop_sequences = ["stop", "end"]
|
|
||||||
req.system = None
|
|
||||||
engine = MagicMock()
|
|
||||||
engine.tokenizer.apply_chat_template.return_value = ""
|
|
||||||
_, _, stops = builder.prepare(req, engine)
|
|
||||||
assert stops == ["stop", "end"]
|
|
||||||
|
|
||||||
def test_format_stream_start(self, builder):
|
|
||||||
ctx = _make_ctx(prompt_tokens=3)
|
|
||||||
events = builder.format_stream_start(ctx)
|
|
||||||
payloads = _sse_payloads(events)
|
|
||||||
assert len(payloads) == 2
|
|
||||||
assert payloads[0]["type"] == "message_start"
|
|
||||||
assert payloads[0]["message"]["usage"]["input_tokens"] == 3
|
|
||||||
assert payloads[1]["type"] == "content_block_start"
|
|
||||||
|
|
||||||
def test_format_chunk(self, builder):
|
|
||||||
event = builder.format_chunk("tok")
|
|
||||||
payload = json.loads(event.split("data: ", 1)[1])
|
|
||||||
assert payload["type"] == "content_block_delta"
|
|
||||||
assert payload["delta"]["text"] == "tok"
|
|
||||||
|
|
||||||
def test_format_stream_end_no_stop(self, builder):
|
|
||||||
ctx = _make_ctx(completion_tokens=3)
|
|
||||||
stop = StopInfo()
|
|
||||||
events = builder.format_stream_end(ctx, stop)
|
|
||||||
payloads = _sse_payloads(events)
|
|
||||||
# content_block_stop, message_delta, message_stop
|
|
||||||
types = [p["type"] for p in payloads]
|
|
||||||
assert types == ["content_block_stop", "message_delta", "message_stop"]
|
|
||||||
assert payloads[1]["delta"]["stop_reason"] == "end_turn"
|
|
||||||
|
|
||||||
def test_format_stream_end_with_stop_trims_and_emits_remaining(self, builder):
|
|
||||||
ctx = _make_ctx(completion_tokens=7)
|
|
||||||
stop = StopInfo(
|
|
||||||
matched="END",
|
|
||||||
body="Hello world END extra",
|
|
||||||
yielded="Hello ",
|
|
||||||
)
|
|
||||||
events = builder.format_stream_end(ctx, stop)
|
|
||||||
payloads = _sse_payloads(events)
|
|
||||||
# unyielded delta, content_block_stop, message_delta, message_stop
|
|
||||||
types = [p["type"] for p in payloads]
|
|
||||||
assert types == [
|
|
||||||
"content_block_delta",
|
|
||||||
"content_block_stop",
|
|
||||||
"message_delta",
|
|
||||||
"message_stop",
|
|
||||||
]
|
|
||||||
assert payloads[0]["delta"]["text"] == "world "
|
|
||||||
assert payloads[2]["delta"]["stop_reason"] == "stop_sequence"
|
|
||||||
assert payloads[2]["delta"]["stop_sequence"] == "END"
|
|
||||||
|
|
||||||
def test_format_stream_end_stop_trimmed_already_yielded(self, builder):
|
|
||||||
ctx = _make_ctx()
|
|
||||||
stop = StopInfo(
|
|
||||||
matched="END",
|
|
||||||
body="Hello END",
|
|
||||||
yielded="Hello ",
|
|
||||||
)
|
|
||||||
events = builder.format_stream_end(ctx, stop)
|
|
||||||
payloads = _sse_payloads(events)
|
|
||||||
# No unyielded delta (everything already sent)
|
|
||||||
types = [p["type"] for p in payloads]
|
|
||||||
assert types == ["content_block_stop", "message_delta", "message_stop"]
|
|
||||||
|
|
||||||
def test_format_response_with_stop_trims_content(self, builder):
|
|
||||||
ctx = _make_ctx()
|
|
||||||
stop = StopInfo(matched="STOP", body="text STOP extra", yielded="text ")
|
|
||||||
resp = builder.format_response(ctx, "text STOP extra", stop)
|
|
||||||
assert resp["content"][0]["text"] == "text "
|
|
||||||
assert resp["stop_reason"] == "stop_sequence"
|
|
||||||
assert resp["stop_sequence"] == "STOP"
|
|
||||||
|
|
||||||
def test_format_response_no_stop(self, builder):
|
|
||||||
ctx = _make_ctx()
|
|
||||||
stop = StopInfo()
|
|
||||||
resp = builder.format_response(ctx, "full text", stop)
|
|
||||||
assert resp["content"][0]["text"] == "full text"
|
|
||||||
assert resp["stop_reason"] == "end_turn"
|
|
||||||
|
|
||||||
|
|
||||||
class TestGenerationRequestValidation:
|
|
||||||
def test_valid_params(self):
|
|
||||||
gr = GenerationRequest(
|
|
||||||
messages=[{"role": "user", "content": "hi"}],
|
|
||||||
top_k=50,
|
|
||||||
top_p=0.9,
|
|
||||||
temperature=0.7,
|
|
||||||
)
|
|
||||||
assert gr.top_k == 50
|
|
||||||
|
|
||||||
def test_invalid_top_p_raises(self):
|
|
||||||
with pytest.raises(ValueError, match="top_p"):
|
|
||||||
GenerationRequest(messages=[{"role": "user", "content": "hi"}], top_p=1.5)
|
|
||||||
|
|
||||||
def test_invalid_top_k_raises(self):
|
|
||||||
with pytest.raises(ValueError, match="top_k"):
|
|
||||||
GenerationRequest(messages=[{"role": "user", "content": "hi"}], top_k=-1)
|
|
||||||
|
|
||||||
def test_invalid_temperature_raises(self):
|
|
||||||
with pytest.raises(ValueError, match="temperature"):
|
|
||||||
GenerationRequest(
|
|
||||||
messages=[{"role": "user", "content": "hi"}], temperature=-0.1
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_top_k_zero_valid(self):
|
|
||||||
gr = GenerationRequest(messages=[{"role": "user", "content": "hi"}], top_k=0)
|
|
||||||
assert gr.top_k == 0
|
|
||||||
|
|
@ -173,21 +173,3 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
|
||||||
for stats in results["stats"]:
|
for stats in results["stats"]:
|
||||||
assert "total_tasks" in stats
|
assert "total_tasks" in stats
|
||||||
assert stats["total_tasks"] >= 0
|
assert stats["total_tasks"] >= 0
|
||||||
|
|
||||||
|
|
||||||
def test_prefill_skips_fully_cached_tasks(mock_model_and_tokenizer):
|
|
||||||
"""Tasks whose entire prompt is cached skip the prefill phase."""
|
|
||||||
mock_model, mock_tokenizer = mock_model_and_tokenizer
|
|
||||||
|
|
||||||
with patch("astrai.inference.core.scheduler.AutoModel"):
|
|
||||||
with patch("astrai.inference.core.scheduler.AutoTokenizer"):
|
|
||||||
scheduler = InferenceScheduler(
|
|
||||||
model=mock_model,
|
|
||||||
tokenizer=mock_tokenizer,
|
|
||||||
max_batch_size=4,
|
|
||||||
device="cpu",
|
|
||||||
)
|
|
||||||
|
|
||||||
task_id = scheduler.add_task("short prompt", stream_callback=lambda t: None)
|
|
||||||
scheduler.stop()
|
|
||||||
assert task_id.startswith("task_")
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue