Compare commits
6 Commits
737585a32a
...
65ab69543b
| Author | SHA1 | Date |
|---|---|---|
|
|
65ab69543b | |
|
|
1d26aa2e93 | |
|
|
a548d4553e | |
|
|
dd1b39f435 | |
|
|
94d6e713e9 | |
|
|
47c37e4876 |
|
|
@ -43,6 +43,7 @@ class ResumableDistributedSampler(Sampler[int]):
|
|||
offset = 0 if drop_last else self.num_replicas - 1
|
||||
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
|
||||
self.total_size = self.num_samples_per_replica * self.num_replicas
|
||||
self.iter = self.iter % self.num_samples_per_replica
|
||||
|
||||
self._indices = None
|
||||
|
||||
|
|
@ -74,5 +75,10 @@ class ResumableDistributedSampler(Sampler[int]):
|
|||
self.epoch += 1
|
||||
self._indices = None
|
||||
|
||||
@property
|
||||
def _remaining(self):
|
||||
remaining = self.num_samples_per_replica - self.iter
|
||||
return max(remaining, 0)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples_per_replica
|
||||
return self._remaining
|
||||
|
|
|
|||
|
|
@ -1,25 +1,27 @@
|
|||
"""Inference module for continuous batching.
|
||||
|
||||
Layers:
|
||||
- core/: Core inference loop (cache, executor, scheduler, task)
|
||||
- api/: HTTP protocol handlers (OpenAI, Anthropic)
|
||||
- engine.py: Facade (InferenceEngine), Value Object (GenerationRequest)
|
||||
- sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
|
||||
- core/: Core inference loop (cache, executor, scheduler, task)
|
||||
- api/: HTTP orchestration (ProtocolHandler, server)
|
||||
- protocols/: Response builders (OpenAI, Anthropic)
|
||||
- transport/: SSE transport utilities
|
||||
- engine.py: Facade (InferenceEngine), Value Object (GenerationRequest)
|
||||
- sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
|
||||
"""
|
||||
|
||||
from astrai.inference.api import (
|
||||
AnthropicHandler,
|
||||
AnthropicMessage,
|
||||
ChatCompletionRequest,
|
||||
ChatMessage,
|
||||
GenContext,
|
||||
MessagesRequest,
|
||||
OpenAIHandler,
|
||||
ProtocolHandler,
|
||||
StopChecker,
|
||||
StreamContext,
|
||||
app,
|
||||
run_server,
|
||||
)
|
||||
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
||||
from astrai.inference.api.openai import OpenAIResponseBuilder
|
||||
from astrai.inference.core import (
|
||||
STOP,
|
||||
Allocator,
|
||||
|
|
@ -36,10 +38,7 @@ from astrai.inference.core import (
|
|||
TaskTable,
|
||||
page_hash,
|
||||
)
|
||||
from astrai.inference.engine import (
|
||||
GenerationRequest,
|
||||
InferenceEngine,
|
||||
)
|
||||
from astrai.inference.engine import GenerationRequest, InferenceEngine
|
||||
from astrai.inference.sample import (
|
||||
BaseSamplingStrategy,
|
||||
SamplingPipeline,
|
||||
|
|
@ -50,17 +49,14 @@ from astrai.inference.sample import (
|
|||
)
|
||||
|
||||
__all__ = [
|
||||
# Engine / Requests
|
||||
"InferenceEngine",
|
||||
"GenerationRequest",
|
||||
# Core scheduler
|
||||
"InferenceScheduler",
|
||||
"Executor",
|
||||
"STOP",
|
||||
"Task",
|
||||
"TaskManager",
|
||||
"TaskStatus",
|
||||
# Core cache
|
||||
"Allocator",
|
||||
"KVCache",
|
||||
"KvcacheView",
|
||||
|
|
@ -69,20 +65,17 @@ __all__ = [
|
|||
"Storage",
|
||||
"TaskTable",
|
||||
"page_hash",
|
||||
# Sampling (Strategy pattern)
|
||||
"sample",
|
||||
"BaseSamplingStrategy",
|
||||
"TemperatureStrategy",
|
||||
"TopKStrategy",
|
||||
"TopPStrategy",
|
||||
"SamplingPipeline",
|
||||
# Protocol
|
||||
"ProtocolHandler",
|
||||
"StopChecker",
|
||||
"StreamContext",
|
||||
"AnthropicHandler",
|
||||
"OpenAIHandler",
|
||||
# Server
|
||||
"GenContext",
|
||||
"OpenAIResponseBuilder",
|
||||
"AnthropicResponseBuilder",
|
||||
"ChatMessage",
|
||||
"ChatCompletionRequest",
|
||||
"AnthropicMessage",
|
||||
|
|
|
|||
|
|
@ -1,12 +1,6 @@
|
|||
"""Inference API: protocol handlers and FastAPI server."""
|
||||
"""Inference API: protocol handler, stop checker, and FastAPI server."""
|
||||
|
||||
from astrai.inference.api.protocol import (
|
||||
AnthropicHandler,
|
||||
OpenAIHandler,
|
||||
ProtocolHandler,
|
||||
StopChecker,
|
||||
StreamContext,
|
||||
)
|
||||
from astrai.inference.api.protocol import GenContext, ProtocolHandler, StopChecker
|
||||
from astrai.inference.api.server import (
|
||||
AnthropicMessage,
|
||||
ChatCompletionRequest,
|
||||
|
|
@ -17,11 +11,9 @@ from astrai.inference.api.server import (
|
|||
)
|
||||
|
||||
__all__ = [
|
||||
"AnthropicHandler",
|
||||
"OpenAIHandler",
|
||||
"ProtocolHandler",
|
||||
"StopChecker",
|
||||
"StreamContext",
|
||||
"GenContext",
|
||||
"AnthropicMessage",
|
||||
"ChatCompletionRequest",
|
||||
"ChatMessage",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,140 @@
|
|||
"""Anthropic message completion response builder."""
|
||||
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from astrai.inference.api.protocol import (
|
||||
GenContext,
|
||||
ResponseBuilder,
|
||||
StopInfo,
|
||||
sse_event,
|
||||
)
|
||||
from astrai.inference.engine import InferenceEngine
|
||||
|
||||
|
||||
def _extract_text(content: Union[str, List[Dict[str, Any]]]) -> str:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
return block.get("text", "")
|
||||
return ""
|
||||
|
||||
|
||||
class AnthropicResponseBuilder(ResponseBuilder):
|
||||
def prepare(
|
||||
self, request: BaseModel, engine: InferenceEngine
|
||||
) -> Tuple[str, GenContext, List[str]]:
|
||||
messages: List[Dict[str, str]] = []
|
||||
system = getattr(request, "system", None)
|
||||
if system:
|
||||
messages.append({"role": "system", "content": system})
|
||||
for m in request.messages:
|
||||
text = _extract_text(m.content)
|
||||
if text:
|
||||
messages.append({"role": m.role, "content": text})
|
||||
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
ctx = GenContext(
|
||||
resp_id=f"msg_{uuid.uuid4().hex[:24]}",
|
||||
created=0,
|
||||
model=request.model,
|
||||
prompt_tokens=0,
|
||||
)
|
||||
stop_sequences = getattr(request, "stop_sequences", None) or []
|
||||
return prompt, ctx, stop_sequences
|
||||
|
||||
def format_stream_start(self, ctx: GenContext) -> List[str]:
|
||||
return [
|
||||
sse_event(
|
||||
{
|
||||
"type": "message_start",
|
||||
"message": {
|
||||
"id": ctx.resp_id,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": ctx.model,
|
||||
"content": [],
|
||||
"usage": {"input_tokens": ctx.prompt_tokens},
|
||||
},
|
||||
},
|
||||
event="message_start",
|
||||
),
|
||||
sse_event(
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": {"type": "text", "text": ""},
|
||||
},
|
||||
event="content_block_start",
|
||||
),
|
||||
]
|
||||
|
||||
def format_chunk(self, token: str) -> str:
|
||||
return sse_event(
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": {"type": "text_delta", "text": token},
|
||||
},
|
||||
event="content_block_delta",
|
||||
)
|
||||
|
||||
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
||||
events: List[str] = []
|
||||
if stop.matched:
|
||||
trimmed = stop.body[: stop.body.rfind(stop.matched)]
|
||||
unyielded = trimmed[len(stop.yielded) :]
|
||||
if unyielded:
|
||||
events.append(
|
||||
sse_event(
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": {"type": "text_delta", "text": unyielded},
|
||||
},
|
||||
event="content_block_delta",
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
sse_event(
|
||||
{"type": "content_block_stop", "index": 0},
|
||||
event="content_block_stop",
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
sse_event(
|
||||
{
|
||||
"type": "message_delta",
|
||||
"delta": {
|
||||
"stop_reason": "stop_sequence" if stop.matched else "end_turn",
|
||||
"stop_sequence": stop.matched,
|
||||
},
|
||||
"usage": {"output_tokens": ctx.completion_tokens},
|
||||
},
|
||||
event="message_delta",
|
||||
)
|
||||
)
|
||||
events.append(sse_event({"type": "message_stop"}, event="message_stop"))
|
||||
return events
|
||||
|
||||
def format_response(
|
||||
self, ctx: GenContext, content: str, stop: StopInfo
|
||||
) -> Dict[str, Any]:
|
||||
if stop.matched:
|
||||
content = content[: content.rfind(stop.matched)]
|
||||
return {
|
||||
"id": ctx.resp_id,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": ctx.model,
|
||||
"content": [{"type": "text", "text": content}],
|
||||
"stop_reason": "stop_sequence" if stop.matched else "end_turn",
|
||||
"stop_sequence": stop.matched,
|
||||
"usage": {
|
||||
"input_tokens": ctx.prompt_tokens,
|
||||
"output_tokens": ctx.completion_tokens,
|
||||
},
|
||||
}
|
||||
|
|
@ -0,0 +1,111 @@
|
|||
"""OpenAI chat completion response builder."""
|
||||
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from astrai.inference.api.protocol import (
|
||||
GenContext,
|
||||
ResponseBuilder,
|
||||
StopInfo,
|
||||
sse_event,
|
||||
)
|
||||
from astrai.inference.engine import InferenceEngine
|
||||
|
||||
|
||||
class OpenAIResponseBuilder(ResponseBuilder):
|
||||
def prepare(
|
||||
self, request: BaseModel, engine: InferenceEngine
|
||||
) -> Tuple[str, GenContext, List[str]]:
|
||||
messages = [{"role": m.role, "content": m.content} for m in request.messages]
|
||||
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
|
||||
self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
||||
self._model = request.model
|
||||
|
||||
ctx = GenContext(
|
||||
resp_id=self._resp_id,
|
||||
created=0,
|
||||
model=self._model,
|
||||
prompt_tokens=0,
|
||||
)
|
||||
stop = request.stop
|
||||
stop_sequences = (
|
||||
[] if stop is None else [stop] if isinstance(stop, str) else stop
|
||||
)
|
||||
return prompt, ctx, stop_sequences
|
||||
|
||||
def format_stream_start(self, ctx: GenContext) -> List[str]:
|
||||
return [
|
||||
sse_event(
|
||||
{
|
||||
"id": self._resp_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": ctx.created,
|
||||
"model": self._model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"role": "assistant"},
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
def format_chunk(self, token: str) -> str:
|
||||
return sse_event(
|
||||
{
|
||||
"id": self._resp_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 0,
|
||||
"model": self._model,
|
||||
"choices": [
|
||||
{"index": 0, "delta": {"content": token}, "finish_reason": None}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
||||
return [
|
||||
sse_event(
|
||||
{
|
||||
"id": self._resp_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": ctx.created,
|
||||
"model": self._model,
|
||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||
}
|
||||
),
|
||||
sse_event(
|
||||
{
|
||||
"prompt_tokens": ctx.prompt_tokens,
|
||||
"completion_tokens": ctx.completion_tokens,
|
||||
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
||||
}
|
||||
),
|
||||
]
|
||||
|
||||
def format_response(
|
||||
self, ctx: GenContext, content: str, stop: StopInfo
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": self._resp_id,
|
||||
"object": "chat.completion",
|
||||
"created": ctx.created,
|
||||
"model": self._model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": content},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": ctx.prompt_tokens,
|
||||
"completion_tokens": ctx.completion_tokens,
|
||||
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
||||
},
|
||||
}
|
||||
|
|
@ -1,15 +1,13 @@
|
|||
"""Protocol handlers for OpenAI and Anthropic chat completion APIs.
|
||||
"""Orchestration layer: ProtocolHandler, StopChecker, GenContext, StopInfo, ResponseBuilder, SSE utils.
|
||||
|
||||
Template Method + Builder patterns eliminate the 45% code duplication between
|
||||
stream/non-stream branches and across protocol adapters.
|
||||
ProtocolHandler orchestrates the async generation loop and delegates
|
||||
protocol-specific formatting to a ResponseBuilder.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -17,7 +15,7 @@ from pydantic import BaseModel
|
|||
from astrai.inference.engine import InferenceEngine
|
||||
|
||||
|
||||
def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
||||
def sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
||||
lines: List[str] = []
|
||||
if event:
|
||||
lines.append(f"event: {event}")
|
||||
|
|
@ -26,22 +24,28 @@ def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
|||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _sse_done() -> str:
|
||||
def sse_done() -> str:
|
||||
return "data: [DONE]\n\n"
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamContext:
|
||||
"""Shared state across the streaming generation lifecycle."""
|
||||
class GenContext:
|
||||
"""Per-generation metadata passed to builder format methods."""
|
||||
|
||||
resp_id: str
|
||||
created: int
|
||||
model: str
|
||||
prompt_tokens: int
|
||||
completion_tokens: int = 0
|
||||
accumulated: str = ""
|
||||
stop_matched: Optional[str] = None
|
||||
last_yield_trimmed: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class StopInfo:
|
||||
"""Stop-check result passed to format_stream_end / format_response."""
|
||||
|
||||
matched: Optional[str] = None
|
||||
body: str = ""
|
||||
yielded: str = ""
|
||||
|
||||
|
||||
class StopChecker:
|
||||
|
|
@ -56,95 +60,60 @@ class StopChecker:
|
|||
return seq
|
||||
return None
|
||||
|
||||
def trim(self, text: str, matched: str) -> str:
|
||||
idx = text.rfind(matched)
|
||||
return text[:idx] if idx != -1 else text
|
||||
|
||||
@property
|
||||
def has_sequences(self) -> bool:
|
||||
return len(self._sequences) > 0
|
||||
class ResponseBuilder(ABC):
|
||||
"""Interface for protocol-specific response formatting.
|
||||
|
||||
|
||||
class ProtocolHandler(ABC):
|
||||
"""Template-method base for API protocol handlers.
|
||||
|
||||
Subclasses implement format hooks; the base class orchestrates the
|
||||
generate-async loop and SSE/JSON response construction.
|
||||
|
||||
Lifecycle::
|
||||
|
||||
handle()
|
||||
├─ build_prompt() # protocol-specific prompt assembly
|
||||
├─ create_response_id() # unique response identifier
|
||||
├─ [stream]
|
||||
│ ├─ format_stream_start()
|
||||
│ ├─ format_stream_token() × N
|
||||
│ │ └─ on_token() hook for stop-sequence interception
|
||||
│ └─ format_stream_end()
|
||||
└─ [non-stream]
|
||||
├─ (accumulate tokens)
|
||||
└─ format_non_stream_response()
|
||||
A new protocol requires one concrete builder implementing 6 methods.
|
||||
"""
|
||||
|
||||
request_model: type[BaseModel]
|
||||
@abstractmethod
|
||||
def prepare(
|
||||
self, request: BaseModel, engine: InferenceEngine
|
||||
) -> Tuple[str, GenContext, List[str]]:
|
||||
"""Return (prompt, ctx, stop_sequences) for a generation request."""
|
||||
|
||||
def __init__(self, request: BaseModel, engine: InferenceEngine):
|
||||
@abstractmethod
|
||||
def format_stream_start(self, ctx: GenContext) -> List[str]:
|
||||
"""SSE events that open the stream."""
|
||||
|
||||
@abstractmethod
|
||||
def format_chunk(self, token: str) -> str:
|
||||
"""SSE event for a single generated token."""
|
||||
|
||||
@abstractmethod
|
||||
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
||||
"""SSE events that close the stream."""
|
||||
|
||||
@abstractmethod
|
||||
def format_response(
|
||||
self, ctx: GenContext, content: str, stop: StopInfo
|
||||
) -> Dict[str, Any]:
|
||||
"""JSON response body for non-streaming mode."""
|
||||
|
||||
|
||||
class ProtocolHandler:
|
||||
"""Orchestrates the generation loop, delegates formatting to a builder.
|
||||
|
||||
Usage::
|
||||
|
||||
handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
|
||||
response = await handler.handle()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, request: BaseModel, engine: InferenceEngine, builder: ResponseBuilder
|
||||
):
|
||||
self.request = request
|
||||
self.engine = engine
|
||||
|
||||
@abstractmethod
|
||||
def build_prompt(self) -> str:
|
||||
"""Build the full prompt string from the request messages."""
|
||||
|
||||
@abstractmethod
|
||||
def create_response_id(self) -> str:
|
||||
"""Generate a unique response ID following the protocol convention."""
|
||||
|
||||
@abstractmethod
|
||||
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
||||
"""Yield SSE events that open the stream (role marker, metadata)."""
|
||||
|
||||
@abstractmethod
|
||||
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
|
||||
"""Yield an SSE event for a single generated token."""
|
||||
|
||||
@abstractmethod
|
||||
def format_stream_end(self, ctx: StreamContext) -> List[str]:
|
||||
"""Yield SSE events that close the stream (finish reason, usage stats)."""
|
||||
|
||||
@abstractmethod
|
||||
def format_non_stream_response(
|
||||
self, ctx: StreamContext, content: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Build the JSON response body for non-streaming mode."""
|
||||
|
||||
def get_stop_sequences(self) -> List[str]:
|
||||
return []
|
||||
|
||||
def create_stop_checker(self) -> StopChecker:
|
||||
return StopChecker(self.get_stop_sequences())
|
||||
|
||||
def on_token(
|
||||
self, ctx: StreamContext, token: str, stop_checker: StopChecker
|
||||
) -> Optional[str]:
|
||||
"""Hook after each token is appended to accumulated.
|
||||
|
||||
Return a matched stop-sequence string to break the loop,
|
||||
or None to continue.
|
||||
|
||||
"""
|
||||
return None
|
||||
self.builder = builder
|
||||
|
||||
async def handle(self) -> Union[StreamingResponse, Dict[str, Any]]:
|
||||
ctx = StreamContext(
|
||||
resp_id=self.create_response_id(),
|
||||
created=int(time.time()),
|
||||
model=self.request.model,
|
||||
prompt_tokens=self._count_prompt_tokens(),
|
||||
)
|
||||
prompt, ctx, stop_sequences = self.builder.prepare(self.request, self.engine)
|
||||
ctx.prompt_tokens = len(self.engine.tokenizer.encode(prompt))
|
||||
|
||||
agen = self.engine.generate_async(
|
||||
prompt=self.build_prompt(),
|
||||
prompt=prompt,
|
||||
max_tokens=self.request.max_tokens,
|
||||
temperature=self.request.temperature,
|
||||
top_p=self.request.top_p,
|
||||
|
|
@ -152,33 +121,37 @@ class ProtocolHandler(ABC):
|
|||
)
|
||||
|
||||
if self.request.stream:
|
||||
return self._handle_stream(agen, ctx)
|
||||
return self._handle_stream(agen, ctx, stop_sequences)
|
||||
else:
|
||||
return await self._handle_non_stream(agen, ctx)
|
||||
return await self._handle_non_stream(agen, ctx, stop_sequences)
|
||||
|
||||
def _count_prompt_tokens(self) -> int:
|
||||
return len(self.engine.tokenizer.encode(self.build_prompt()))
|
||||
|
||||
def _handle_stream(self, agen, ctx: StreamContext) -> StreamingResponse:
|
||||
stop_checker = self.create_stop_checker()
|
||||
def _handle_stream(
|
||||
self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str]
|
||||
) -> StreamingResponse:
|
||||
checker = StopChecker(stop_sequences)
|
||||
|
||||
async def event_stream():
|
||||
for event in self.format_stream_start(ctx):
|
||||
for event in self.builder.format_stream_start(ctx):
|
||||
yield event
|
||||
|
||||
body = ""
|
||||
yielded = ""
|
||||
matched = None
|
||||
async for token in agen:
|
||||
ctx.completion_tokens += 1
|
||||
ctx.accumulated += token
|
||||
body += token
|
||||
|
||||
matched = self.on_token(ctx, token, stop_checker)
|
||||
matched = checker.check(body)
|
||||
if matched:
|
||||
break
|
||||
|
||||
yield self.format_stream_token(ctx, token)
|
||||
yield self.builder.format_chunk(token)
|
||||
yielded += token
|
||||
|
||||
for event in self.format_stream_end(ctx):
|
||||
stop = StopInfo(matched=matched, body=body, yielded=yielded)
|
||||
for event in self.builder.format_stream_end(ctx, stop):
|
||||
yield event
|
||||
yield _sse_done()
|
||||
yield sse_done()
|
||||
|
||||
return StreamingResponse(
|
||||
event_stream(),
|
||||
|
|
@ -186,260 +159,23 @@ class ProtocolHandler(ABC):
|
|||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||
)
|
||||
|
||||
async def _handle_non_stream(self, agen, ctx: StreamContext) -> Dict[str, Any]:
|
||||
stop_checker = self.create_stop_checker()
|
||||
async def _handle_non_stream(
|
||||
self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str]
|
||||
) -> Dict[str, Any]:
|
||||
checker = StopChecker(stop_sequences)
|
||||
chunks: List[str] = []
|
||||
body = ""
|
||||
matched = None
|
||||
|
||||
async for token in agen:
|
||||
ctx.completion_tokens += 1
|
||||
ctx.accumulated += token
|
||||
chunks.append(token)
|
||||
body += token
|
||||
|
||||
matched = self.on_token(ctx, token, stop_checker)
|
||||
matched = checker.check(body)
|
||||
if matched:
|
||||
break
|
||||
|
||||
content = "".join(chunks)
|
||||
return self.format_non_stream_response(ctx, content)
|
||||
|
||||
|
||||
def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str:
|
||||
"""Extract plain text from an Anthropic content block (string or list)."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
return block.get("text", "")
|
||||
return ""
|
||||
|
||||
|
||||
class OpenAIHandler(ProtocolHandler):
|
||||
"""OpenAI-compatible /v1/chat/completions handler."""
|
||||
|
||||
def build_prompt(self) -> str:
|
||||
messages = [
|
||||
{"role": m.role, "content": m.content} for m in self.request.messages
|
||||
]
|
||||
return self.engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
|
||||
def create_response_id(self) -> str:
|
||||
return f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
||||
|
||||
def get_stop_sequences(self) -> List[str]:
|
||||
stop = self.request.stop
|
||||
if stop is None:
|
||||
return []
|
||||
return [stop] if isinstance(stop, str) else stop
|
||||
|
||||
def on_token(
|
||||
self, ctx: StreamContext, token: str, stop_checker: StopChecker
|
||||
) -> Optional[str]:
|
||||
return stop_checker.check(ctx.accumulated)
|
||||
|
||||
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
||||
return [
|
||||
_sse_event(
|
||||
{
|
||||
"id": ctx.resp_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": ctx.created,
|
||||
"model": ctx.model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"role": "assistant"},
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
|
||||
return _sse_event(
|
||||
{
|
||||
"id": ctx.resp_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": ctx.created,
|
||||
"model": ctx.model,
|
||||
"choices": [
|
||||
{"index": 0, "delta": {"content": token}, "finish_reason": None}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
def format_stream_end(self, ctx: StreamContext) -> List[str]:
|
||||
return [
|
||||
_sse_event(
|
||||
{
|
||||
"id": ctx.resp_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": ctx.created,
|
||||
"model": ctx.model,
|
||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||
}
|
||||
),
|
||||
_sse_event(
|
||||
{
|
||||
"prompt_tokens": ctx.prompt_tokens,
|
||||
"completion_tokens": ctx.completion_tokens,
|
||||
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
||||
}
|
||||
),
|
||||
]
|
||||
|
||||
def format_non_stream_response(
|
||||
self, ctx: StreamContext, content: str
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": ctx.resp_id,
|
||||
"object": "chat.completion",
|
||||
"created": ctx.created,
|
||||
"model": ctx.model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": content},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": ctx.prompt_tokens,
|
||||
"completion_tokens": ctx.completion_tokens,
|
||||
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class AnthropicHandler(ProtocolHandler):
|
||||
"""Anthropic-compatible /v1/messages handler."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._yielded = ""
|
||||
|
||||
def build_prompt(self) -> str:
|
||||
messages: List[Dict[str, str]] = []
|
||||
system = getattr(self.request, "system", None)
|
||||
if system:
|
||||
messages.append({"role": "system", "content": system})
|
||||
for m in self.request.messages:
|
||||
content = _extract_text_content(m.content)
|
||||
if content:
|
||||
messages.append({"role": m.role, "content": content})
|
||||
return self.engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
|
||||
def create_response_id(self) -> str:
|
||||
return f"msg_{uuid.uuid4().hex[:24]}"
|
||||
|
||||
def get_stop_sequences(self) -> List[str]:
|
||||
return getattr(self.request, "stop_sequences", None) or []
|
||||
|
||||
def on_token(
|
||||
self, ctx: StreamContext, token: str, stop_checker: StopChecker
|
||||
) -> Optional[str]:
|
||||
matched = stop_checker.check(ctx.accumulated)
|
||||
if not matched:
|
||||
return None
|
||||
|
||||
ctx.stop_matched = matched
|
||||
trimmed = ctx.accumulated[: ctx.accumulated.rfind(matched)]
|
||||
unyielded = trimmed[len(self._yielded) :]
|
||||
if unyielded:
|
||||
ctx.last_yield_trimmed = unyielded
|
||||
return matched
|
||||
|
||||
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
||||
return [
|
||||
_sse_event(
|
||||
{
|
||||
"type": "message_start",
|
||||
"message": {
|
||||
"id": ctx.resp_id,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": ctx.model,
|
||||
"content": [],
|
||||
"usage": {"input_tokens": ctx.prompt_tokens},
|
||||
},
|
||||
},
|
||||
event="message_start",
|
||||
),
|
||||
_sse_event(
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": {"type": "text", "text": ""},
|
||||
},
|
||||
event="content_block_start",
|
||||
),
|
||||
]
|
||||
|
||||
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
|
||||
self._yielded += token
|
||||
return _sse_event(
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": {"type": "text_delta", "text": token},
|
||||
},
|
||||
event="content_block_delta",
|
||||
)
|
||||
|
||||
def format_stream_end(self, ctx: StreamContext) -> List[str]:
|
||||
matched = ctx.stop_matched
|
||||
events: List[str] = []
|
||||
last_yielded = ctx.last_yield_trimmed
|
||||
if last_yielded:
|
||||
events.append(
|
||||
_sse_event(
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": {"type": "text_delta", "text": last_yielded},
|
||||
},
|
||||
event="content_block_delta",
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
_sse_event(
|
||||
{"type": "content_block_stop", "index": 0},
|
||||
event="content_block_stop",
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
_sse_event(
|
||||
{
|
||||
"type": "message_delta",
|
||||
"delta": {
|
||||
"stop_reason": "stop_sequence" if matched else "end_turn",
|
||||
"stop_sequence": matched,
|
||||
},
|
||||
"usage": {"output_tokens": ctx.completion_tokens},
|
||||
},
|
||||
event="message_delta",
|
||||
)
|
||||
)
|
||||
events.append(_sse_event({"type": "message_stop"}, event="message_stop"))
|
||||
return events
|
||||
|
||||
def format_non_stream_response(
|
||||
self, ctx: StreamContext, content: str
|
||||
) -> Dict[str, Any]:
|
||||
matched = ctx.stop_matched
|
||||
if matched:
|
||||
content = content[: content.rfind(matched)]
|
||||
return {
|
||||
"id": ctx.resp_id,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": ctx.model,
|
||||
"content": [{"type": "text", "text": content}],
|
||||
"stop_reason": "stop_sequence" if matched else "end_turn",
|
||||
"stop_sequence": matched,
|
||||
"usage": {
|
||||
"input_tokens": ctx.prompt_tokens,
|
||||
"output_tokens": ctx.completion_tokens,
|
||||
},
|
||||
}
|
||||
stop = StopInfo(matched=matched, body=body)
|
||||
return self.builder.format_response(ctx, content, stop)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,9 @@ import uvicorn
|
|||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from astrai.inference.api.protocol import AnthropicHandler, OpenAIHandler
|
||||
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
||||
from astrai.inference.api.openai import OpenAIResponseBuilder
|
||||
from astrai.inference.api.protocol import ProtocolHandler
|
||||
from astrai.inference.engine import InferenceEngine
|
||||
from astrai.model import AutoModel
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
|
@ -133,14 +135,14 @@ async def get_stats():
|
|||
@app.post("/v1/chat/completions")
|
||||
async def chat_completion(request: ChatCompletionRequest):
|
||||
engine = _get_engine()
|
||||
handler = OpenAIHandler(request, engine)
|
||||
handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
|
||||
return await handler.handle()
|
||||
|
||||
|
||||
@app.post("/v1/messages")
|
||||
async def create_message(request: MessagesRequest):
|
||||
engine = _get_engine()
|
||||
handler = AnthropicHandler(request, engine)
|
||||
handler = ProtocolHandler(request, engine, AnthropicResponseBuilder())
|
||||
return await handler.handle()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -108,7 +108,10 @@ class InferenceScheduler:
|
|||
continue
|
||||
|
||||
to_prefill = [
|
||||
t for t in self._task_mgr.get_active_tasks() if t.output_tokens == 0
|
||||
t
|
||||
for t in self._task_mgr.get_active_tasks()
|
||||
if t.output_tokens == 0
|
||||
and self._page_cache.task_cached(t.task_id) < len(t.prompt_ids)
|
||||
]
|
||||
if to_prefill:
|
||||
for t in to_prefill:
|
||||
|
|
|
|||
|
|
@ -13,17 +13,6 @@ from astrai.inference.core.task import STOP
|
|||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
|
||||
def _validate_sampling_params(
|
||||
top_k: int, top_p: float, temperature: float, max_tokens: Optional[int] = None
|
||||
):
|
||||
if not (isinstance(top_k, int) and top_k >= 0):
|
||||
raise ValueError("top_k must be a non-negative integer")
|
||||
if not (0.0 <= top_p <= 1.0):
|
||||
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
||||
if not (isinstance(temperature, (int, float)) and temperature >= 0):
|
||||
raise ValueError("temperature must be a non-negative number")
|
||||
|
||||
|
||||
class GenerateResult:
|
||||
"""Thread-safe token accumulator for streaming and non-streaming modes."""
|
||||
|
||||
|
|
@ -86,7 +75,12 @@ class GenerationRequest:
|
|||
max_tokens: Optional[int] = None,
|
||||
stream: bool = False,
|
||||
):
|
||||
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
|
||||
if not (isinstance(top_k, int) and top_k >= 0):
|
||||
raise ValueError("top_k must be a non-negative integer")
|
||||
if not (0.0 <= top_p <= 1.0):
|
||||
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
||||
if not (isinstance(temperature, (int, float)) and temperature >= 0):
|
||||
raise ValueError("temperature must be a non-negative number")
|
||||
|
||||
self.messages = messages
|
||||
self.top_k = top_k
|
||||
|
|
@ -137,7 +131,6 @@ class InferenceEngine:
|
|||
top_p: float = 1.0,
|
||||
top_k: int = 50,
|
||||
) -> Union[Generator, str, List[str]]:
|
||||
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
|
||||
is_batch = isinstance(prompt, list)
|
||||
prompts = prompt if is_batch else [prompt]
|
||||
|
||||
|
|
@ -158,7 +151,6 @@ class InferenceEngine:
|
|||
top_p: float = 1.0,
|
||||
top_k: int = 50,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
|
||||
sync_gen = self._generate_streaming(
|
||||
[prompt], False, max_tokens, temperature, top_p, top_k
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,16 +2,15 @@
|
|||
AutoModel base class for model loading and saving.
|
||||
"""
|
||||
|
||||
import json
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Self, Union
|
||||
|
||||
import safetensors.torch as st
|
||||
import torch.nn as nn
|
||||
|
||||
from astrai.config.model_config import BaseModelConfig, ConfigFactory
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.serialization import load_model_config, load_model_weights, save_model
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
|
@ -60,25 +59,22 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
|||
|
||||
model_path = Path(path)
|
||||
|
||||
# Load config
|
||||
config_path = model_path / "config.json"
|
||||
if config_path.exists():
|
||||
with open(config_path, "r") as f:
|
||||
raw = json.load(f)
|
||||
config = ConfigFactory.load(raw)
|
||||
model_type = config.model_type or "autoregressive_lm"
|
||||
else:
|
||||
if not config_path.exists():
|
||||
raise FileNotFoundError(f"Config file not found: {config_path}")
|
||||
|
||||
raw = load_model_config(str(model_path))
|
||||
config = ConfigFactory.load(raw)
|
||||
model_type = config.model_type or "autoregressive_lm"
|
||||
|
||||
actual_cls = AutoModel.get_component_class(model_type)
|
||||
|
||||
with _disable_random_init(enable=disable_random_init):
|
||||
model = actual_cls(config)
|
||||
|
||||
# Load weights
|
||||
weights_path = model_path / "model.safetensors"
|
||||
if weights_path.exists():
|
||||
state_dict = st.load_file(str(weights_path))
|
||||
state_dict = load_model_weights(str(model_path))
|
||||
model.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
return model
|
||||
|
|
@ -87,14 +83,11 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
|||
self,
|
||||
save_directory: Union[str, Path],
|
||||
) -> None:
|
||||
save_path = Path(save_directory)
|
||||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save config
|
||||
self.config.to_file(str(save_path / "config.json"))
|
||||
|
||||
# Save weights
|
||||
st.save_file(self.state_dict(), str(save_path / "model.safetensors"))
|
||||
save_model(
|
||||
config=self.config.to_dict(),
|
||||
state_dict=self.state_dict(),
|
||||
save_directory=str(save_directory),
|
||||
)
|
||||
|
||||
def to(self, *args, **kwargs) -> Self:
|
||||
"""Move model to device/dtype."""
|
||||
|
|
|
|||
|
|
@ -1,15 +1,19 @@
|
|||
import json
|
||||
import logging
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Set
|
||||
|
||||
import safetensors.torch as st
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from astrai.model.components.linear import Linear
|
||||
from astrai.serialization import (
|
||||
load_json,
|
||||
load_safetensors,
|
||||
save_json,
|
||||
save_safetensors,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -128,16 +132,14 @@ def save_lora(model: nn.Module, save_dir: str, config: LoRAConfig):
|
|||
|
||||
path = Path(save_dir)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
st.save_file(lora_sd, str(path / "adapter_model.safetensors"))
|
||||
with open(path / "adapter_config.json", "w") as f:
|
||||
json.dump(asdict(config), f, indent=2)
|
||||
save_safetensors(lora_sd, path / "adapter_model.safetensors")
|
||||
save_json(asdict(config), path / "adapter_config.json")
|
||||
logger.info("LoRA adapter saved to %s (%d keys)", save_dir, len(lora_sd))
|
||||
|
||||
|
||||
def load_lora(model: nn.Module, load_dir: str) -> LoRAConfig:
|
||||
path = Path(load_dir)
|
||||
with open(path / "adapter_config.json") as f:
|
||||
raw = json.load(f)
|
||||
raw = load_json(path / "adapter_config.json")
|
||||
config = LoRAConfig(
|
||||
r=raw["r"], alpha=raw["alpha"], target_modules=tuple(raw["target_modules"])
|
||||
)
|
||||
|
|
@ -157,7 +159,7 @@ def load_lora(model: nn.Module, load_dir: str) -> LoRAConfig:
|
|||
target_modules=set(config.target_modules),
|
||||
)
|
||||
|
||||
weights = st.load_file(str(path / "adapter_model.safetensors"))
|
||||
weights = load_safetensors(path / "adapter_model.safetensors")
|
||||
try:
|
||||
missing, unexpected = model.load_state_dict(weights, strict=False)
|
||||
except RuntimeError as e:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import json
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict
|
||||
|
||||
import safetensors.torch as st
|
||||
import torch
|
||||
|
|
@ -9,75 +10,101 @@ import torch.distributed as dist
|
|||
|
||||
from astrai.parallel.setup import get_rank
|
||||
|
||||
_META_FILE = "meta.json"
|
||||
_WEIGHTS_FILE = "model.safetensors"
|
||||
_MODEL_CONFIG_FILE = "config.json"
|
||||
|
||||
|
||||
def save_safetensors(state_dict: dict, path: str | Path) -> None:
|
||||
st.save_file(state_dict, str(path))
|
||||
|
||||
|
||||
def load_safetensors(path: str | Path) -> dict:
|
||||
return st.load_file(str(path))
|
||||
|
||||
|
||||
def save_json(data: dict, path: str | Path) -> None:
|
||||
with open(str(path), "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
|
||||
def load_json(path: str | Path) -> dict:
|
||||
with open(str(path), "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def save_torch(obj: Any, path: str | Path) -> None:
|
||||
torch.save(obj, str(path))
|
||||
|
||||
|
||||
def load_torch(path: str | Path) -> Any:
|
||||
return torch.load(str(path), map_location="cpu", weights_only=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Checkpoint:
|
||||
def __init__(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
epoch: int = 0,
|
||||
iteration: int = 0,
|
||||
extra: Optional[Dict[str, Any]] = None,
|
||||
meta: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
self.state_dict = state_dict
|
||||
self.epoch = epoch
|
||||
self.iteration = iteration
|
||||
self.extra = extra or {}
|
||||
self.meta = meta or {}
|
||||
|
||||
def save(
|
||||
self,
|
||||
save_dir: str,
|
||||
) -> None:
|
||||
state_dict: Dict[str, Any] = field(default_factory=dict)
|
||||
epoch: int = 0
|
||||
iteration: int = 0
|
||||
extra: Dict[str, Any] = field(default_factory=dict)
|
||||
meta: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def save(self, save_dir: str) -> None:
|
||||
save_path = Path(save_dir)
|
||||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
rank = get_rank()
|
||||
if rank == 0:
|
||||
meta = {
|
||||
"epoch": self.epoch,
|
||||
"iteration": self.iteration,
|
||||
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
|
||||
}
|
||||
meta.update(self.meta)
|
||||
with open(save_path / "meta.json", "w") as f:
|
||||
json.dump(meta, f, indent=2)
|
||||
if get_rank() != 0:
|
||||
return
|
||||
|
||||
st.save_file(self.state_dict, save_path / "state_dict.safetensors")
|
||||
if self.extra:
|
||||
for key, value in self.extra.items():
|
||||
torch.save(value, save_path / f"{key}.pt")
|
||||
meta = {
|
||||
"epoch": self.epoch,
|
||||
"iteration": self.iteration,
|
||||
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
|
||||
**self.meta,
|
||||
}
|
||||
save_json(meta, save_path / _META_FILE)
|
||||
save_safetensors(self.state_dict, save_path / _WEIGHTS_FILE)
|
||||
for key, value in self.extra.items():
|
||||
save_torch(value, save_path / f"{key}.pt")
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
save_dir: str,
|
||||
) -> "Checkpoint":
|
||||
|
||||
rank = get_rank()
|
||||
def load(cls, save_dir: str) -> "Checkpoint":
|
||||
save_path = Path(save_dir)
|
||||
|
||||
meta = {}
|
||||
if rank == 0:
|
||||
with open(Path(save_dir) / "meta.json", "r") as f:
|
||||
meta = json.load(f)
|
||||
if get_rank() == 0:
|
||||
meta = load_json(save_path / _META_FILE)
|
||||
|
||||
if dist.is_initialized():
|
||||
meta_list = [meta]
|
||||
dist.broadcast_object_list(meta_list, src=0)
|
||||
meta = meta_list[0]
|
||||
|
||||
state_dict = st.load_file(save_path / "state_dict.safetensors")
|
||||
state_dict = load_safetensors(save_path / _WEIGHTS_FILE)
|
||||
|
||||
extra = {}
|
||||
for f in save_path.iterdir():
|
||||
if f.suffix == ".pt" and f.stem not in ("meta",):
|
||||
extra[f.stem] = torch.load(f, map_location="cpu", weights_only=False)
|
||||
if f.suffix == ".pt":
|
||||
extra[f.stem] = load_torch(f)
|
||||
|
||||
return cls(
|
||||
state_dict=state_dict,
|
||||
epoch=meta["epoch"],
|
||||
iteration=meta["iteration"],
|
||||
extra=extra or None,
|
||||
epoch=meta.get("epoch", 0),
|
||||
iteration=meta.get("iteration", 0),
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
def save_model(config: dict, state_dict: dict, save_directory: str) -> None:
|
||||
save_path = Path(save_directory)
|
||||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
save_json(config, save_path / _MODEL_CONFIG_FILE)
|
||||
save_safetensors(state_dict, save_path / _WEIGHTS_FILE)
|
||||
|
||||
|
||||
def load_model_config(save_directory: str) -> dict:
|
||||
return load_json(Path(save_directory) / _MODEL_CONFIG_FILE)
|
||||
|
||||
|
||||
def load_model_weights(save_directory: str) -> dict:
|
||||
return load_safetensors(Path(save_directory) / _WEIGHTS_FILE)
|
||||
|
|
|
|||
|
|
@ -210,7 +210,7 @@ class ProgressBarCallback(TrainCallback):
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self, num_epoch: int, log_interval: int = 100, file: IO[str] = sys.stdout
|
||||
self, num_epoch: int, log_interval: int = 100, file: Optional[IO[str]] = None
|
||||
):
|
||||
self.num_epoch = num_epoch
|
||||
self.log_interval = log_interval
|
||||
|
|
@ -223,7 +223,7 @@ class ProgressBarCallback(TrainCallback):
|
|||
context.dataloader,
|
||||
desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
|
||||
dynamic_ncols=True,
|
||||
file=self.file,
|
||||
file=self.file or sys.stdout,
|
||||
)
|
||||
|
||||
@only_on_rank(0)
|
||||
|
|
|
|||
|
|
@ -71,7 +71,8 @@ class TrainContextBuilder:
|
|||
if self._checkpoint is not None:
|
||||
context.epoch = max(self._checkpoint.epoch, cfg.start_epoch)
|
||||
context.iteration = max(self._checkpoint.iteration, cfg.start_batch)
|
||||
context.model.load_state_dict(self._checkpoint.state_dict)
|
||||
if self._checkpoint.state_dict:
|
||||
context.model.load_state_dict(self._checkpoint.state_dict)
|
||||
context.checkpoint = self._checkpoint
|
||||
else:
|
||||
context.checkpoint = Checkpoint(
|
||||
|
|
|
|||
|
|
@ -2,13 +2,13 @@ import argparse
|
|||
import os
|
||||
from functools import partial
|
||||
|
||||
import safetensors.torch as st
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
|
||||
from astrai.config import AutoRegressiveLMConfig, TrainConfig
|
||||
from astrai.dataset import DatasetFactory
|
||||
from astrai.model import AutoRegressiveLM
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.trainer import SchedulerFactory, Trainer
|
||||
|
||||
|
||||
|
|
@ -236,16 +236,14 @@ def train(
|
|||
if window_size is None:
|
||||
window_size = config.max_len
|
||||
|
||||
# Create bare AutoRegressiveLM (for training, no tokenizer needed)
|
||||
model = AutoRegressiveLM(config)
|
||||
# Create model and load full checkpoint (state_dict + optimizer + scheduler + meta)
|
||||
checkpoint = Checkpoint.load(param_path)
|
||||
model = AutoRegressiveLM(config).to(dtype=torch.bfloat16)
|
||||
model.load_state_dict(checkpoint.state_dict, strict=False)
|
||||
|
||||
# Load weights if available
|
||||
weights_path = os.path.join(param_path, "model.safetensors")
|
||||
if os.path.exists(weights_path):
|
||||
state_dict = st.load_file(weights_path)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
model = model.to(dtype=torch.bfloat16)
|
||||
# Strip state_dict to avoid pickling ~7GB through mp.spawn pipe
|
||||
# (model weights already loaded into model above)
|
||||
checkpoint.state_dict = {}
|
||||
|
||||
strategy_kwargs = {
|
||||
"beta": dpo_beta,
|
||||
|
|
@ -257,8 +255,6 @@ def train(
|
|||
}
|
||||
|
||||
executor_kwargs = {
|
||||
"static_graph": True,
|
||||
"find_unused_parameters": False,
|
||||
"gradient_as_bucket_view": True,
|
||||
"broadcast_buffers": False,
|
||||
}
|
||||
|
|
@ -319,7 +315,7 @@ def train(
|
|||
)
|
||||
|
||||
trainer = Trainer(train_config)
|
||||
trainer.train()
|
||||
trainer.train(checkpoint=checkpoint)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
|
|
@ -36,7 +37,6 @@ def test_single_process():
|
|||
|
||||
|
||||
def test_checkpoint_with_extra():
|
||||
"""Verify extra keys are saved as individual .pt files and loaded back."""
|
||||
model = torch.nn.Linear(10, 5)
|
||||
optimizer = AdamW(model.parameters(), lr=1e-3)
|
||||
optimizer.step()
|
||||
|
|
@ -52,8 +52,6 @@ def test_checkpoint_with_extra():
|
|||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
checkpoint.save(tmpdir)
|
||||
|
||||
import os
|
||||
|
||||
assert os.path.exists(os.path.join(tmpdir, "optimizer.pt"))
|
||||
assert os.path.exists(os.path.join(tmpdir, "scheduler.pt"))
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,286 @@
|
|||
"""Unit tests for protocol builders, StopChecker, GenContext, StopInfo."""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
||||
from astrai.inference.api.openai import OpenAIResponseBuilder
|
||||
from astrai.inference.api.protocol import GenContext, StopChecker, StopInfo
|
||||
from astrai.inference.engine import GenerationRequest
|
||||
|
||||
|
||||
def _make_ctx(**kwargs):
|
||||
defaults = {
|
||||
"resp_id": "test-123",
|
||||
"created": 1000,
|
||||
"model": "test-model",
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 5,
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return GenContext(**defaults)
|
||||
|
||||
|
||||
def _sse_payloads(events):
|
||||
payloads = []
|
||||
for chunk in events:
|
||||
for line in chunk.strip().split("\n"):
|
||||
if line.startswith("data: "):
|
||||
try:
|
||||
payloads.append(json.loads(line[6:]))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return payloads
|
||||
|
||||
|
||||
class TestStopChecker:
|
||||
def test_check_finds_match(self):
|
||||
sc = StopChecker(["stop", "end"])
|
||||
assert sc.check("hello stop world") == "stop"
|
||||
|
||||
def test_check_returns_none_when_no_match(self):
|
||||
sc = StopChecker(["stop"])
|
||||
assert sc.check("hello world") is None
|
||||
|
||||
def test_check_empty_sequences(self):
|
||||
sc = StopChecker([])
|
||||
assert sc.check("hello") is None
|
||||
|
||||
|
||||
class TestGenContext:
|
||||
def test_defaults(self):
|
||||
ctx = GenContext(resp_id="a", created=1, model="m", prompt_tokens=10)
|
||||
assert ctx.completion_tokens == 0
|
||||
|
||||
def test_fields_mutable(self):
|
||||
ctx = GenContext(resp_id="a", created=1, model="m", prompt_tokens=10)
|
||||
ctx.completion_tokens = 42
|
||||
assert ctx.completion_tokens == 42
|
||||
|
||||
|
||||
class TestStopInfo:
|
||||
def test_defaults(self):
|
||||
s = StopInfo()
|
||||
assert s.matched is None
|
||||
assert s.body == ""
|
||||
assert s.yielded == ""
|
||||
|
||||
def test_with_values(self):
|
||||
s = StopInfo(matched="stop", body="hello stop", yielded="hello ")
|
||||
assert s.matched == "stop"
|
||||
assert s.body == "hello stop"
|
||||
assert s.yielded == "hello "
|
||||
|
||||
|
||||
class TestOpenAIResponseBuilder:
|
||||
@pytest.fixture
|
||||
def builder(self):
|
||||
builder = OpenAIResponseBuilder()
|
||||
req = MagicMock()
|
||||
req.messages = [MagicMock(role="user", content="Hello")]
|
||||
req.stop = None
|
||||
req.model = "astrai"
|
||||
engine = MagicMock()
|
||||
engine.tokenizer.apply_chat_template.return_value = "Hello"
|
||||
builder.prepare(req, engine)
|
||||
return builder
|
||||
|
||||
def test_prepare_returns_prompt_ctx_stops(self, builder):
|
||||
req = MagicMock()
|
||||
req.messages = [MagicMock(role="user", content="Hi")]
|
||||
req.stop = ["END"]
|
||||
req.model = "gpt"
|
||||
engine = MagicMock()
|
||||
engine.tokenizer.apply_chat_template.return_value = "Hi"
|
||||
prompt, ctx, stops = builder.prepare(req, engine)
|
||||
assert prompt == "Hi"
|
||||
assert ctx.model == "gpt"
|
||||
assert ctx.prompt_tokens == 0
|
||||
assert stops == ["END"]
|
||||
|
||||
def test_prepare_no_stop_returns_empty_list(self, builder):
|
||||
req = MagicMock()
|
||||
req.messages = []
|
||||
req.stop = None
|
||||
req.model = "x"
|
||||
engine = MagicMock()
|
||||
engine.tokenizer.apply_chat_template.return_value = ""
|
||||
_, _, stops = builder.prepare(req, engine)
|
||||
assert stops == []
|
||||
|
||||
def test_format_stream_start(self, builder):
|
||||
ctx = _make_ctx()
|
||||
events = builder.format_stream_start(ctx)
|
||||
payloads = _sse_payloads(events)
|
||||
assert len(payloads) == 1
|
||||
p = payloads[0]
|
||||
assert p["object"] == "chat.completion.chunk"
|
||||
assert p["choices"][0]["delta"]["role"] == "assistant"
|
||||
assert p["choices"][0]["finish_reason"] is None
|
||||
|
||||
def test_format_chunk(self, builder):
|
||||
event = builder.format_chunk("hello")
|
||||
payload = json.loads(event.split("data: ", 1)[1])
|
||||
assert payload["choices"][0]["delta"]["content"] == "hello"
|
||||
assert payload["choices"][0]["finish_reason"] is None
|
||||
|
||||
def test_format_stream_end(self, builder):
|
||||
ctx = _make_ctx(completion_tokens=5)
|
||||
stop = StopInfo(matched="stop")
|
||||
events = builder.format_stream_end(ctx, stop)
|
||||
payloads = _sse_payloads(events)
|
||||
finish = payloads[0]
|
||||
assert finish["choices"][0]["finish_reason"] == "stop"
|
||||
usage = payloads[1]
|
||||
assert usage["completion_tokens"] == 5
|
||||
assert usage["total_tokens"] == 15
|
||||
|
||||
def test_format_response(self, builder):
|
||||
ctx = _make_ctx()
|
||||
stop = StopInfo()
|
||||
resp = builder.format_response(ctx, "hello", stop)
|
||||
assert resp["object"] == "chat.completion"
|
||||
assert resp["choices"][0]["message"]["content"] == "hello"
|
||||
assert resp["usage"]["prompt_tokens"] == 10
|
||||
|
||||
|
||||
class TestAnthropicResponseBuilder:
|
||||
@pytest.fixture
|
||||
def builder(self):
|
||||
builder = AnthropicResponseBuilder()
|
||||
req = MagicMock()
|
||||
req.messages = [MagicMock(role="user", content="Hello")]
|
||||
req.model = "claude"
|
||||
engine = MagicMock()
|
||||
engine.tokenizer.apply_chat_template.return_value = "Hello"
|
||||
req.system = None
|
||||
builder.prepare(req, engine)
|
||||
return builder
|
||||
|
||||
def test_prepare_messages(self, builder):
|
||||
req = MagicMock()
|
||||
req.messages = [MagicMock(role="user", content="Hi")]
|
||||
req.model = "claude"
|
||||
req.system = None
|
||||
req.stop_sequences = None
|
||||
engine = MagicMock()
|
||||
engine.tokenizer.apply_chat_template.return_value = "Hi"
|
||||
prompt, ctx, stops = builder.prepare(req, engine)
|
||||
assert prompt == "Hi"
|
||||
assert stops == []
|
||||
|
||||
def test_prepare_with_stop_sequences(self, builder):
|
||||
req = MagicMock()
|
||||
req.messages = []
|
||||
req.model = "x"
|
||||
req.stop_sequences = ["stop", "end"]
|
||||
req.system = None
|
||||
engine = MagicMock()
|
||||
engine.tokenizer.apply_chat_template.return_value = ""
|
||||
_, _, stops = builder.prepare(req, engine)
|
||||
assert stops == ["stop", "end"]
|
||||
|
||||
def test_format_stream_start(self, builder):
|
||||
ctx = _make_ctx(prompt_tokens=3)
|
||||
events = builder.format_stream_start(ctx)
|
||||
payloads = _sse_payloads(events)
|
||||
assert len(payloads) == 2
|
||||
assert payloads[0]["type"] == "message_start"
|
||||
assert payloads[0]["message"]["usage"]["input_tokens"] == 3
|
||||
assert payloads[1]["type"] == "content_block_start"
|
||||
|
||||
def test_format_chunk(self, builder):
|
||||
event = builder.format_chunk("tok")
|
||||
payload = json.loads(event.split("data: ", 1)[1])
|
||||
assert payload["type"] == "content_block_delta"
|
||||
assert payload["delta"]["text"] == "tok"
|
||||
|
||||
def test_format_stream_end_no_stop(self, builder):
|
||||
ctx = _make_ctx(completion_tokens=3)
|
||||
stop = StopInfo()
|
||||
events = builder.format_stream_end(ctx, stop)
|
||||
payloads = _sse_payloads(events)
|
||||
# content_block_stop, message_delta, message_stop
|
||||
types = [p["type"] for p in payloads]
|
||||
assert types == ["content_block_stop", "message_delta", "message_stop"]
|
||||
assert payloads[1]["delta"]["stop_reason"] == "end_turn"
|
||||
|
||||
def test_format_stream_end_with_stop_trims_and_emits_remaining(self, builder):
|
||||
ctx = _make_ctx(completion_tokens=7)
|
||||
stop = StopInfo(
|
||||
matched="END",
|
||||
body="Hello world END extra",
|
||||
yielded="Hello ",
|
||||
)
|
||||
events = builder.format_stream_end(ctx, stop)
|
||||
payloads = _sse_payloads(events)
|
||||
# unyielded delta, content_block_stop, message_delta, message_stop
|
||||
types = [p["type"] for p in payloads]
|
||||
assert types == [
|
||||
"content_block_delta",
|
||||
"content_block_stop",
|
||||
"message_delta",
|
||||
"message_stop",
|
||||
]
|
||||
assert payloads[0]["delta"]["text"] == "world "
|
||||
assert payloads[2]["delta"]["stop_reason"] == "stop_sequence"
|
||||
assert payloads[2]["delta"]["stop_sequence"] == "END"
|
||||
|
||||
def test_format_stream_end_stop_trimmed_already_yielded(self, builder):
|
||||
ctx = _make_ctx()
|
||||
stop = StopInfo(
|
||||
matched="END",
|
||||
body="Hello END",
|
||||
yielded="Hello ",
|
||||
)
|
||||
events = builder.format_stream_end(ctx, stop)
|
||||
payloads = _sse_payloads(events)
|
||||
# No unyielded delta (everything already sent)
|
||||
types = [p["type"] for p in payloads]
|
||||
assert types == ["content_block_stop", "message_delta", "message_stop"]
|
||||
|
||||
def test_format_response_with_stop_trims_content(self, builder):
|
||||
ctx = _make_ctx()
|
||||
stop = StopInfo(matched="STOP", body="text STOP extra", yielded="text ")
|
||||
resp = builder.format_response(ctx, "text STOP extra", stop)
|
||||
assert resp["content"][0]["text"] == "text "
|
||||
assert resp["stop_reason"] == "stop_sequence"
|
||||
assert resp["stop_sequence"] == "STOP"
|
||||
|
||||
def test_format_response_no_stop(self, builder):
|
||||
ctx = _make_ctx()
|
||||
stop = StopInfo()
|
||||
resp = builder.format_response(ctx, "full text", stop)
|
||||
assert resp["content"][0]["text"] == "full text"
|
||||
assert resp["stop_reason"] == "end_turn"
|
||||
|
||||
|
||||
class TestGenerationRequestValidation:
|
||||
def test_valid_params(self):
|
||||
gr = GenerationRequest(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
top_k=50,
|
||||
top_p=0.9,
|
||||
temperature=0.7,
|
||||
)
|
||||
assert gr.top_k == 50
|
||||
|
||||
def test_invalid_top_p_raises(self):
|
||||
with pytest.raises(ValueError, match="top_p"):
|
||||
GenerationRequest(messages=[{"role": "user", "content": "hi"}], top_p=1.5)
|
||||
|
||||
def test_invalid_top_k_raises(self):
|
||||
with pytest.raises(ValueError, match="top_k"):
|
||||
GenerationRequest(messages=[{"role": "user", "content": "hi"}], top_k=-1)
|
||||
|
||||
def test_invalid_temperature_raises(self):
|
||||
with pytest.raises(ValueError, match="temperature"):
|
||||
GenerationRequest(
|
||||
messages=[{"role": "user", "content": "hi"}], temperature=-0.1
|
||||
)
|
||||
|
||||
def test_top_k_zero_valid(self):
|
||||
gr = GenerationRequest(messages=[{"role": "user", "content": "hi"}], top_k=0)
|
||||
assert gr.top_k == 0
|
||||
|
|
@ -173,3 +173,21 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
|
|||
for stats in results["stats"]:
|
||||
assert "total_tasks" in stats
|
||||
assert stats["total_tasks"] >= 0
|
||||
|
||||
|
||||
def test_prefill_skips_fully_cached_tasks(mock_model_and_tokenizer):
|
||||
"""Tasks whose entire prompt is cached skip the prefill phase."""
|
||||
mock_model, mock_tokenizer = mock_model_and_tokenizer
|
||||
|
||||
with patch("astrai.inference.core.scheduler.AutoModel"):
|
||||
with patch("astrai.inference.core.scheduler.AutoTokenizer"):
|
||||
scheduler = InferenceScheduler(
|
||||
model=mock_model,
|
||||
tokenizer=mock_tokenizer,
|
||||
max_batch_size=4,
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
task_id = scheduler.add_task("short prompt", stream_callback=lambda t: None)
|
||||
scheduler.stop()
|
||||
assert task_id.startswith("task_")
|
||||
|
|
|
|||
Loading…
Reference in New Issue