refactor: 消除多处重复模式,统一工厂和参数传递
- AutoModel 继承 BaseFactory,消除自建 Registry(-30 行)
- executor.execute_prefill 删除重复 forward 代码块(bug)
- train_callback 移除 Protocol 上矛盾的 issubclass 检查
- engine.py 内部方法统一传 GenerationParams,校验内聚
- protocol.py SSEBuilder 类→函数,handle() 用 GenerationParams
- StreamContext 动态属性改为显式 dataclass 字段
- BaseFactory 新增 get_component_class 方法
This commit is contained in:
parent
2196c34c52
commit
18fe6e9339
|
|
@ -155,6 +155,26 @@ class BaseFactory(ABC, Generic[T]):
|
|||
"""
|
||||
pass
|
||||
|
||||
@classmethod
def get_component_class(cls, name: str) -> Type[T]:
    """Look up a registered component class without creating an instance.

    Args:
        name: Name under which the component was registered.

    Returns:
        The class object stored in the registry for ``name``.

    Raises:
        ValueError: If nothing is registered under ``name``; the message
            lists the registered names in sorted order.
    """
    registry = cls._registry
    if registry.contains(name):
        return registry.get(name)

    # Unknown name: report every registered name so the caller can spot typos.
    supported = sorted(registry.list_names())
    raise ValueError(
        f"Unknown component: '{name}'. "
        f"Supported types: {supported}"
    )
|
||||
|
||||
@classmethod
|
||||
def list_registered(cls) -> list:
|
||||
"""List all registered component names.
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ from astrai.inference.api import (
|
|||
MessagesRequest,
|
||||
OpenAIHandler,
|
||||
ProtocolHandler,
|
||||
SSEBuilder,
|
||||
StopChecker,
|
||||
StreamContext,
|
||||
app,
|
||||
|
|
@ -77,7 +76,6 @@ __all__ = [
|
|||
"SamplingPipeline",
|
||||
# Protocol
|
||||
"ProtocolHandler",
|
||||
"SSEBuilder",
|
||||
"StopChecker",
|
||||
"StreamContext",
|
||||
"AnthropicHandler",
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ from astrai.inference.api.protocol import (
|
|||
AnthropicHandler,
|
||||
OpenAIHandler,
|
||||
ProtocolHandler,
|
||||
SSEBuilder,
|
||||
StopChecker,
|
||||
StreamContext,
|
||||
)
|
||||
|
|
@ -21,7 +20,6 @@ __all__ = [
|
|||
"AnthropicHandler",
|
||||
"OpenAIHandler",
|
||||
"ProtocolHandler",
|
||||
"SSEBuilder",
|
||||
"StopChecker",
|
||||
"StreamContext",
|
||||
"AnthropicMessage",
|
||||
|
|
|
|||
|
|
@ -14,24 +14,20 @@ from typing import Any, Dict, List, Optional, Union
|
|||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from astrai.inference.engine import InferenceEngine
|
||||
from astrai.inference.engine import GenerationParams, InferenceEngine
|
||||
|
||||
|
||||
class SSEBuilder:
|
||||
"""Fluent builder for SSE (Server-Sent Events) formatted chunks."""
|
||||
def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
    """Format one Server-Sent Events frame.

    Args:
        data: JSON-serialisable payload, written on a single ``data:`` line.
            Non-ASCII characters are emitted as-is (``ensure_ascii=False``).
        event: Optional event name. When truthy, an ``event:`` line is
            written before the ``data:`` line.

    Returns:
        The frame text, terminated by a blank line (``"\n\n"``).
    """
    fields: List[str] = []
    if event:
        fields.append(f"event: {event}")
    fields.append(f"data: {json.dumps(data, ensure_ascii=False)}")
    # An SSE event is only dispatched when the parser sees a blank line.
    # Ending the frame with a single "\n" lets consecutive frames merge into
    # one event on the client, so terminate every frame with "\n\n".
    return "\n".join(fields) + "\n\n"
|
||||
|
||||
@staticmethod
|
||||
def event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
||||
lines: List[str] = []
|
||||
if event:
|
||||
lines.append(f"event: {event}")
|
||||
lines.append(f"data: {json.dumps(data, ensure_ascii=False)}")
|
||||
lines.append("")
|
||||
return "\n".join(lines)
|
||||
|
||||
@staticmethod
|
||||
def done() -> str:
|
||||
return "data: [DONE]\n\n"
|
||||
def _sse_done() -> str:
    """Return the literal ``[DONE]`` SSE frame, including its blank-line terminator."""
    done_frame = "data: [DONE]\n\n"
    return done_frame
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -44,6 +40,8 @@ class StreamContext:
|
|||
prompt_tokens: int
|
||||
completion_tokens: int = 0
|
||||
accumulated: str = ""
|
||||
stop_matched: Optional[str] = None
|
||||
last_yield_trimmed: str = ""
|
||||
|
||||
|
||||
class StopChecker:
|
||||
|
|
@ -145,13 +143,13 @@ class ProtocolHandler(ABC):
|
|||
prompt_tokens=self._count_prompt_tokens(),
|
||||
)
|
||||
|
||||
agen = self.engine.generate_async(
|
||||
prompt=self.build_prompt(),
|
||||
params = GenerationParams(
|
||||
max_tokens=self.request.max_tokens,
|
||||
temperature=self.request.temperature,
|
||||
top_p=self.request.top_p,
|
||||
top_k=self.request.top_k,
|
||||
)
|
||||
agen = self.engine.generate_async(prompt=self.build_prompt(), params=params)
|
||||
|
||||
if self.request.stream:
|
||||
return self._handle_stream(agen, ctx)
|
||||
|
|
@ -180,7 +178,7 @@ class ProtocolHandler(ABC):
|
|||
|
||||
for event in self.format_stream_end(ctx):
|
||||
yield event
|
||||
yield SSEBuilder.done()
|
||||
yield _sse_done()
|
||||
|
||||
return StreamingResponse(
|
||||
event_stream(),
|
||||
|
|
@ -230,7 +228,7 @@ class OpenAIHandler(ProtocolHandler):
|
|||
|
||||
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
||||
return [
|
||||
SSEBuilder.event(
|
||||
_sse_event(
|
||||
{
|
||||
"id": ctx.resp_id,
|
||||
"object": "chat.completion.chunk",
|
||||
|
|
@ -248,7 +246,7 @@ class OpenAIHandler(ProtocolHandler):
|
|||
]
|
||||
|
||||
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
|
||||
return SSEBuilder.event(
|
||||
return _sse_event(
|
||||
{
|
||||
"id": ctx.resp_id,
|
||||
"object": "chat.completion.chunk",
|
||||
|
|
@ -262,7 +260,7 @@ class OpenAIHandler(ProtocolHandler):
|
|||
|
||||
def format_stream_end(self, ctx: StreamContext) -> List[str]:
|
||||
return [
|
||||
SSEBuilder.event(
|
||||
_sse_event(
|
||||
{
|
||||
"id": ctx.resp_id,
|
||||
"object": "chat.completion.chunk",
|
||||
|
|
@ -271,7 +269,7 @@ class OpenAIHandler(ProtocolHandler):
|
|||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||
}
|
||||
),
|
||||
SSEBuilder.event(
|
||||
_sse_event(
|
||||
{
|
||||
"prompt_tokens": ctx.prompt_tokens,
|
||||
"completion_tokens": ctx.completion_tokens,
|
||||
|
|
@ -334,16 +332,16 @@ class AnthropicHandler(ProtocolHandler):
|
|||
if not matched:
|
||||
return None
|
||||
|
||||
ctx._stop_matched = matched
|
||||
ctx.stop_matched = matched
|
||||
trimmed = ctx.accumulated[: ctx.accumulated.rfind(matched)]
|
||||
unyielded = trimmed[len(self._yielded) :]
|
||||
if unyielded:
|
||||
ctx._last_yield_trimmed = unyielded
|
||||
ctx.last_yield_trimmed = unyielded
|
||||
return matched
|
||||
|
||||
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
||||
return [
|
||||
SSEBuilder.event(
|
||||
_sse_event(
|
||||
{
|
||||
"type": "message_start",
|
||||
"message": {
|
||||
|
|
@ -357,7 +355,7 @@ class AnthropicHandler(ProtocolHandler):
|
|||
},
|
||||
event="message_start",
|
||||
),
|
||||
SSEBuilder.event(
|
||||
_sse_event(
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
|
|
@ -369,7 +367,7 @@ class AnthropicHandler(ProtocolHandler):
|
|||
|
||||
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
|
||||
self._yielded += token
|
||||
return SSEBuilder.event(
|
||||
return _sse_event(
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
|
|
@ -379,12 +377,12 @@ class AnthropicHandler(ProtocolHandler):
|
|||
)
|
||||
|
||||
def format_stream_end(self, ctx: StreamContext) -> List[str]:
|
||||
matched = getattr(ctx, "_stop_matched", None)
|
||||
matched = ctx.stop_matched
|
||||
events: List[str] = []
|
||||
last_yielded = getattr(ctx, "_last_yield_trimmed", "")
|
||||
last_yielded = ctx.last_yield_trimmed
|
||||
if last_yielded:
|
||||
events.append(
|
||||
SSEBuilder.event(
|
||||
_sse_event(
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
|
|
@ -394,13 +392,13 @@ class AnthropicHandler(ProtocolHandler):
|
|||
)
|
||||
)
|
||||
events.append(
|
||||
SSEBuilder.event(
|
||||
_sse_event(
|
||||
{"type": "content_block_stop", "index": 0},
|
||||
event="content_block_stop",
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
SSEBuilder.event(
|
||||
_sse_event(
|
||||
{
|
||||
"type": "message_delta",
|
||||
"delta": {
|
||||
|
|
@ -412,13 +410,13 @@ class AnthropicHandler(ProtocolHandler):
|
|||
event="message_delta",
|
||||
)
|
||||
)
|
||||
events.append(SSEBuilder.event({"type": "message_stop"}, event="message_stop"))
|
||||
events.append(_sse_event({"type": "message_stop"}, event="message_stop"))
|
||||
return events
|
||||
|
||||
def format_non_stream_response(
|
||||
self, ctx: StreamContext, content: str
|
||||
) -> Dict[str, Any]:
|
||||
matched = getattr(ctx, "_stop_matched", None)
|
||||
matched = ctx.stop_matched
|
||||
if matched:
|
||||
content = content[: content.rfind(matched)]
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -60,25 +60,6 @@ class Executor:
|
|||
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
|
||||
)
|
||||
|
||||
for i, t in enumerate(tasks):
|
||||
input_ids[i] = torch.tensor(
|
||||
t.prompt_ids[start_pos:prompt_len], device=self.device
|
||||
)
|
||||
|
||||
task_ids = [t.task_id for t in tasks]
|
||||
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
|
||||
|
||||
with torch.inference_mode():
|
||||
self.model(
|
||||
input_ids,
|
||||
position_ids=torch.arange(
|
||||
start_pos, prompt_len, dtype=torch.long, device=self.device
|
||||
)
|
||||
.unsqueeze(0)
|
||||
.expand(batch_sz, -1),
|
||||
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
|
||||
)
|
||||
|
||||
def execute_decode(self, tasks: List[Task]) -> List[int]:
|
||||
if not tasks:
|
||||
return []
|
||||
|
|
|
|||
|
|
@ -58,15 +58,6 @@ class GenerateResult:
|
|||
return self.results.copy()
|
||||
|
||||
|
||||
def _validate_params(top_k: int, top_p: float, temperature: float) -> None:
|
||||
if not (isinstance(top_k, int) and top_k >= 0):
|
||||
raise ValueError("top_k must be a non-negative integer")
|
||||
if not (0.0 <= top_p <= 1.0):
|
||||
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
||||
if not (isinstance(temperature, (int, float)) and temperature >= 0):
|
||||
raise ValueError("temperature must be a non-negative number")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GenerationParams:
|
||||
"""Immutable value object for sampling hyperparameters."""
|
||||
|
|
@ -76,6 +67,14 @@ class GenerationParams:
|
|||
temperature: float = 1.0
|
||||
max_tokens: int = 1024
|
||||
|
||||
def __post_init__(self):
    """Validate the sampling hyperparameters after dataclass construction.

    Raises:
        ValueError: If ``top_k`` is not a non-negative integer, if ``top_p``
            is not a number within ``[0.0, 1.0]``, or if ``temperature`` is
            not a non-negative number.
    """
    if not (isinstance(self.top_k, int) and self.top_k >= 0):
        raise ValueError("top_k must be a non-negative integer")

    # A non-comparable top_p (e.g. None or a string) makes the range check
    # itself raise TypeError. Convert that into the same ValueError used for
    # the other fields so callers only have one exception type to handle.
    try:
        top_p_in_range = 0.0 <= self.top_p <= 1.0
    except TypeError:
        top_p_in_range = False
    if not top_p_in_range:
        raise ValueError("top_p must be a float between 0.0 and 1.0")

    if not (isinstance(self.temperature, (int, float)) and self.temperature >= 0):
        raise ValueError("temperature must be a non-negative number")
|
||||
|
||||
|
||||
class GenerationRequest:
|
||||
"""Request parameters for text generation."""
|
||||
|
|
@ -97,7 +96,6 @@ class GenerationRequest:
|
|||
max_tokens=max_len,
|
||||
)
|
||||
self.stream = stream
|
||||
_validate_params(top_k, top_p, temperature)
|
||||
|
||||
@property
|
||||
def top_k(self) -> int:
|
||||
|
|
@ -157,31 +155,32 @@ class InferenceEngine:
|
|||
top_p: float = 1.0,
|
||||
top_k: int = 50,
|
||||
) -> Union[Generator, str, List[str]]:
|
||||
_validate_params(top_k, top_p, temperature)
|
||||
params = GenerationParams(
|
||||
top_k=top_k, top_p=top_p, temperature=temperature, max_tokens=max_tokens
|
||||
)
|
||||
|
||||
is_batch = isinstance(prompt, list)
|
||||
prompts = prompt if is_batch else [prompt]
|
||||
|
||||
if stream:
|
||||
return self._generate_streaming(
|
||||
prompts, is_batch, max_tokens, temperature, top_p, top_k
|
||||
)
|
||||
return self._generate_streaming(prompts, is_batch, params)
|
||||
else:
|
||||
return self._generate_non_streaming(
|
||||
prompts, is_batch, max_tokens, temperature, top_p, top_k
|
||||
)
|
||||
return self._generate_non_streaming(prompts, is_batch, params)
|
||||
|
||||
def generate_async(
|
||||
self,
|
||||
prompt: str,
|
||||
params: Optional[GenerationParams] = None,
|
||||
max_tokens: int = 1024,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = 50,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
sync_gen = self._generate_streaming(
|
||||
[prompt], False, max_tokens, temperature, top_p, top_k
|
||||
)
|
||||
if params is None:
|
||||
params = GenerationParams(
|
||||
top_k=top_k, top_p=top_p, temperature=temperature, max_tokens=max_tokens
|
||||
)
|
||||
sync_gen = self._generate_streaming([prompt], False, params)
|
||||
|
||||
async def _agen():
|
||||
loop = asyncio.get_event_loop()
|
||||
|
|
@ -214,12 +213,7 @@ class InferenceEngine:
|
|||
)
|
||||
|
||||
def _submit_tasks(
|
||||
self,
|
||||
prompts: List[str],
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
top_p: float,
|
||||
top_k: int,
|
||||
self, prompts: List[str], params: GenerationParams
|
||||
) -> Tuple[GenerateResult, List[str]]:
|
||||
n = len(prompts)
|
||||
result = GenerateResult(count=n)
|
||||
|
|
@ -228,10 +222,10 @@ class InferenceEngine:
|
|||
cb = self._make_callback(result, i)
|
||||
task_id = self.scheduler.add_task(
|
||||
prompt=p,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
max_tokens=params.max_tokens,
|
||||
temperature=params.temperature,
|
||||
top_p=params.top_p,
|
||||
top_k=params.top_k,
|
||||
stream_callback=cb,
|
||||
)
|
||||
task_ids.append(task_id)
|
||||
|
|
@ -245,17 +239,9 @@ class InferenceEngine:
|
|||
return cb
|
||||
|
||||
def _generate_streaming(
|
||||
self,
|
||||
prompts: List[str],
|
||||
is_batch: bool,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
top_p: float,
|
||||
top_k: int,
|
||||
self, prompts: List[str], is_batch: bool, params: GenerationParams
|
||||
) -> Generator:
|
||||
result, task_ids = self._submit_tasks(
|
||||
prompts, max_tokens, temperature, top_p, top_k
|
||||
)
|
||||
result, task_ids = self._submit_tasks(prompts, params)
|
||||
n = len(prompts)
|
||||
remaining = n
|
||||
finished = [False] * n
|
||||
|
|
@ -281,17 +267,9 @@ class InferenceEngine:
|
|||
return gen()
|
||||
|
||||
def _generate_non_streaming(
|
||||
self,
|
||||
prompts: List[str],
|
||||
is_batch: bool,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
top_p: float,
|
||||
top_k: int,
|
||||
self, prompts: List[str], is_batch: bool, params: GenerationParams
|
||||
) -> Union[str, List[str]]:
|
||||
result, task_ids = self._submit_tasks(
|
||||
prompts, max_tokens, temperature, top_p, top_k
|
||||
)
|
||||
result, task_ids = self._submit_tasks(prompts, params)
|
||||
|
||||
result.wait_completion()
|
||||
|
||||
|
|
|
|||
|
|
@ -4,13 +4,13 @@ AutoModel base class for model loading and saving.
|
|||
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Self, Type, Union
|
||||
from typing import Self, Union
|
||||
|
||||
import safetensors.torch as st
|
||||
import torch.nn as nn
|
||||
|
||||
from astrai.config import ModelConfig
|
||||
from astrai.factory import Registry
|
||||
from astrai.factory import BaseFactory
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
|
@ -39,46 +39,16 @@ def _disable_random_init(enable: bool = True):
|
|||
setattr(nn.init, name, orig_func)
|
||||
|
||||
|
||||
class AutoModel(nn.Module):
|
||||
class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
||||
"""
|
||||
Autoregressive language model base class.
|
||||
Provides model loading/saving and generation capabilities.
|
||||
Provides model loading/saving, registration, and generation.
|
||||
"""
|
||||
|
||||
_registry = Registry()
|
||||
|
||||
def __init__(self, config: ModelConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
@classmethod
|
||||
def register(cls, model_type: str):
|
||||
"""
|
||||
Class method decorator to register model type.
|
||||
|
||||
Usage:
|
||||
@AutoModel.register('transformer')
|
||||
class Transformer(AutoModel):
|
||||
...
|
||||
"""
|
||||
|
||||
def decorator(sub_cls: Type["AutoModel"]) -> Type["AutoModel"]:
|
||||
cls._registry.register(model_type.lower(), sub_cls)
|
||||
return sub_cls
|
||||
|
||||
return decorator
|
||||
|
||||
@classmethod
|
||||
def get_model_class(cls, model_type: str) -> Type["AutoModel"]:
|
||||
"""Get model class by model_type string."""
|
||||
model_type = model_type.lower()
|
||||
if not cls._registry.contains(model_type):
|
||||
available = cls._registry.list_names()
|
||||
raise ValueError(
|
||||
f"Unknown model_type: {model_type}. Available: {available}"
|
||||
)
|
||||
return cls._registry.get(model_type)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
|
|
@ -98,7 +68,7 @@ class AutoModel(nn.Module):
|
|||
raise FileNotFoundError(f"Config file not found: {config_path}")
|
||||
|
||||
model_type = config.model_type or "transformer"
|
||||
actual_cls = cls.get_model_class(model_type)
|
||||
actual_cls = AutoModel.get_component_class(model_type)
|
||||
|
||||
with _disable_random_init(enable=disable_random_init):
|
||||
model = actual_cls(config)
|
||||
|
|
|
|||
|
|
@ -69,12 +69,6 @@ class CallbackFactory(BaseFactory[TrainCallback]):
|
|||
callback = CallbackFactory.create("my_callback", **kwargs)
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _validate_component(cls, callback_cls: type) -> None:
|
||||
"""Validate that the callback class inherits from TrainCallback."""
|
||||
if not issubclass(callback_cls, TrainCallback):
|
||||
raise TypeError(f"{callback_cls.__name__} must inherit from TrainCallback")
|
||||
|
||||
|
||||
@CallbackFactory.register("gradient_clipping")
|
||||
class GradientClippingCallback(TrainCallback):
|
||||
|
|
|
|||
Loading…
Reference in New Issue