refactor: 消除多处重复模式,统一工厂和参数传递

- AutoModel 继承 BaseFactory,消除自建 Registry(-30 行)
- executor.execute_prefill 删除重复 forward 代码块(bug)
- train_callback 移除 Protocol 上矛盾的 issubclass 检查
- engine.py 内部方法统一传 GenerationParams,校验内聚
- protocol.py SSEBuilder 类→函数,handle() 用 GenerationParams
- StreamContext 动态属性改为显式 dataclass 字段
- BaseFactory 新增 get_component_class 方法
This commit is contained in:
ViperEkura 2026-05-14 18:00:50 +08:00
parent 2196c34c52
commit 18fe6e9339
8 changed files with 84 additions and 147 deletions

View File

@ -155,6 +155,26 @@ class BaseFactory(ABC, Generic[T]):
""" """
pass pass
@classmethod
def get_component_class(cls, name: str) -> Type[T]:
    """Look up a registered component class without creating an instance.

    Args:
        name: Name the component was registered under.

    Returns:
        The class stored in the registry for ``name``.

    Raises:
        ValueError: If ``name`` is not present in the registry.
    """
    registry = cls._registry
    if registry.contains(name):
        return registry.get(name)

    # Sort the known names so the error message is stable across runs.
    supported = sorted(registry.list_names())
    raise ValueError(
        f"Unknown component: '{name}'. Supported types: {supported}"
    )
@classmethod @classmethod
def list_registered(cls) -> list: def list_registered(cls) -> list:
"""List all registered component names. """List all registered component names.

View File

@ -15,7 +15,6 @@ from astrai.inference.api import (
MessagesRequest, MessagesRequest,
OpenAIHandler, OpenAIHandler,
ProtocolHandler, ProtocolHandler,
SSEBuilder,
StopChecker, StopChecker,
StreamContext, StreamContext,
app, app,
@ -77,7 +76,6 @@ __all__ = [
"SamplingPipeline", "SamplingPipeline",
# Protocol # Protocol
"ProtocolHandler", "ProtocolHandler",
"SSEBuilder",
"StopChecker", "StopChecker",
"StreamContext", "StreamContext",
"AnthropicHandler", "AnthropicHandler",

View File

@ -4,7 +4,6 @@ from astrai.inference.api.protocol import (
AnthropicHandler, AnthropicHandler,
OpenAIHandler, OpenAIHandler,
ProtocolHandler, ProtocolHandler,
SSEBuilder,
StopChecker, StopChecker,
StreamContext, StreamContext,
) )
@ -21,7 +20,6 @@ __all__ = [
"AnthropicHandler", "AnthropicHandler",
"OpenAIHandler", "OpenAIHandler",
"ProtocolHandler", "ProtocolHandler",
"SSEBuilder",
"StopChecker", "StopChecker",
"StreamContext", "StreamContext",
"AnthropicMessage", "AnthropicMessage",

View File

@ -14,24 +14,20 @@ from typing import Any, Dict, List, Optional, Union
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
from astrai.inference.engine import InferenceEngine from astrai.inference.engine import GenerationParams, InferenceEngine
class SSEBuilder: def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
"""Fluent builder for SSE (Server-Sent Events) formatted chunks.""" lines: List[str] = []
if event:
lines.append(f"event: {event}")
lines.append(f"data: {json.dumps(data, ensure_ascii=False)}")
lines.append("")
return "\n".join(lines)
@staticmethod
def event(data: Dict[str, Any], event: Optional[str] = None) -> str:
    """Serialize ``data`` as one complete SSE (Server-Sent Events) chunk.

    Args:
        data: Payload dumped to JSON on the ``data:`` line; non-ASCII
            characters are written as-is rather than ``\\u``-escaped.
        event: Optional event name; when truthy, an ``event:`` line is
            emitted before the ``data:`` line.

    Returns:
        The ``event:``/``data:`` lines followed by a blank line.
    """
    lines: List[str] = []
    if event:
        lines.append(f"event: {event}")
    lines.append(f"data: {json.dumps(data, ensure_ascii=False)}")
    # An SSE client dispatches an event only when it reads a blank line.
    # Ending on a single "\n" lets back-to-back chunks merge into one event,
    # so terminate with "\n\n" (extra blank lines are ignored by clients).
    return "\n".join(lines) + "\n\n"
@staticmethod def _sse_done() -> str:
def done() -> str: return "data: [DONE]\n\n"
return "data: [DONE]\n\n"
@dataclass @dataclass
@ -44,6 +40,8 @@ class StreamContext:
prompt_tokens: int prompt_tokens: int
completion_tokens: int = 0 completion_tokens: int = 0
accumulated: str = "" accumulated: str = ""
stop_matched: Optional[str] = None
last_yield_trimmed: str = ""
class StopChecker: class StopChecker:
@ -145,13 +143,13 @@ class ProtocolHandler(ABC):
prompt_tokens=self._count_prompt_tokens(), prompt_tokens=self._count_prompt_tokens(),
) )
agen = self.engine.generate_async( params = GenerationParams(
prompt=self.build_prompt(),
max_tokens=self.request.max_tokens, max_tokens=self.request.max_tokens,
temperature=self.request.temperature, temperature=self.request.temperature,
top_p=self.request.top_p, top_p=self.request.top_p,
top_k=self.request.top_k, top_k=self.request.top_k,
) )
agen = self.engine.generate_async(prompt=self.build_prompt(), params=params)
if self.request.stream: if self.request.stream:
return self._handle_stream(agen, ctx) return self._handle_stream(agen, ctx)
@ -180,7 +178,7 @@ class ProtocolHandler(ABC):
for event in self.format_stream_end(ctx): for event in self.format_stream_end(ctx):
yield event yield event
yield SSEBuilder.done() yield _sse_done()
return StreamingResponse( return StreamingResponse(
event_stream(), event_stream(),
@ -230,7 +228,7 @@ class OpenAIHandler(ProtocolHandler):
def format_stream_start(self, ctx: StreamContext) -> List[str]: def format_stream_start(self, ctx: StreamContext) -> List[str]:
return [ return [
SSEBuilder.event( _sse_event(
{ {
"id": ctx.resp_id, "id": ctx.resp_id,
"object": "chat.completion.chunk", "object": "chat.completion.chunk",
@ -248,7 +246,7 @@ class OpenAIHandler(ProtocolHandler):
] ]
def format_stream_token(self, ctx: StreamContext, token: str) -> str: def format_stream_token(self, ctx: StreamContext, token: str) -> str:
return SSEBuilder.event( return _sse_event(
{ {
"id": ctx.resp_id, "id": ctx.resp_id,
"object": "chat.completion.chunk", "object": "chat.completion.chunk",
@ -262,7 +260,7 @@ class OpenAIHandler(ProtocolHandler):
def format_stream_end(self, ctx: StreamContext) -> List[str]: def format_stream_end(self, ctx: StreamContext) -> List[str]:
return [ return [
SSEBuilder.event( _sse_event(
{ {
"id": ctx.resp_id, "id": ctx.resp_id,
"object": "chat.completion.chunk", "object": "chat.completion.chunk",
@ -271,7 +269,7 @@ class OpenAIHandler(ProtocolHandler):
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
} }
), ),
SSEBuilder.event( _sse_event(
{ {
"prompt_tokens": ctx.prompt_tokens, "prompt_tokens": ctx.prompt_tokens,
"completion_tokens": ctx.completion_tokens, "completion_tokens": ctx.completion_tokens,
@ -334,16 +332,16 @@ class AnthropicHandler(ProtocolHandler):
if not matched: if not matched:
return None return None
ctx._stop_matched = matched ctx.stop_matched = matched
trimmed = ctx.accumulated[: ctx.accumulated.rfind(matched)] trimmed = ctx.accumulated[: ctx.accumulated.rfind(matched)]
unyielded = trimmed[len(self._yielded) :] unyielded = trimmed[len(self._yielded) :]
if unyielded: if unyielded:
ctx._last_yield_trimmed = unyielded ctx.last_yield_trimmed = unyielded
return matched return matched
def format_stream_start(self, ctx: StreamContext) -> List[str]: def format_stream_start(self, ctx: StreamContext) -> List[str]:
return [ return [
SSEBuilder.event( _sse_event(
{ {
"type": "message_start", "type": "message_start",
"message": { "message": {
@ -357,7 +355,7 @@ class AnthropicHandler(ProtocolHandler):
}, },
event="message_start", event="message_start",
), ),
SSEBuilder.event( _sse_event(
{ {
"type": "content_block_start", "type": "content_block_start",
"index": 0, "index": 0,
@ -369,7 +367,7 @@ class AnthropicHandler(ProtocolHandler):
def format_stream_token(self, ctx: StreamContext, token: str) -> str: def format_stream_token(self, ctx: StreamContext, token: str) -> str:
self._yielded += token self._yielded += token
return SSEBuilder.event( return _sse_event(
{ {
"type": "content_block_delta", "type": "content_block_delta",
"index": 0, "index": 0,
@ -379,12 +377,12 @@ class AnthropicHandler(ProtocolHandler):
) )
def format_stream_end(self, ctx: StreamContext) -> List[str]: def format_stream_end(self, ctx: StreamContext) -> List[str]:
matched = getattr(ctx, "_stop_matched", None) matched = ctx.stop_matched
events: List[str] = [] events: List[str] = []
last_yielded = getattr(ctx, "_last_yield_trimmed", "") last_yielded = ctx.last_yield_trimmed
if last_yielded: if last_yielded:
events.append( events.append(
SSEBuilder.event( _sse_event(
{ {
"type": "content_block_delta", "type": "content_block_delta",
"index": 0, "index": 0,
@ -394,13 +392,13 @@ class AnthropicHandler(ProtocolHandler):
) )
) )
events.append( events.append(
SSEBuilder.event( _sse_event(
{"type": "content_block_stop", "index": 0}, {"type": "content_block_stop", "index": 0},
event="content_block_stop", event="content_block_stop",
) )
) )
events.append( events.append(
SSEBuilder.event( _sse_event(
{ {
"type": "message_delta", "type": "message_delta",
"delta": { "delta": {
@ -412,13 +410,13 @@ class AnthropicHandler(ProtocolHandler):
event="message_delta", event="message_delta",
) )
) )
events.append(SSEBuilder.event({"type": "message_stop"}, event="message_stop")) events.append(_sse_event({"type": "message_stop"}, event="message_stop"))
return events return events
def format_non_stream_response( def format_non_stream_response(
self, ctx: StreamContext, content: str self, ctx: StreamContext, content: str
) -> Dict[str, Any]: ) -> Dict[str, Any]:
matched = getattr(ctx, "_stop_matched", None) matched = ctx.stop_matched
if matched: if matched:
content = content[: content.rfind(matched)] content = content[: content.rfind(matched)]
return { return {

View File

@ -60,25 +60,6 @@ class Executor:
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len), paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
) )
for i, t in enumerate(tasks):
input_ids[i] = torch.tensor(
t.prompt_ids[start_pos:prompt_len], device=self.device
)
task_ids = [t.task_id for t in tasks]
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
with torch.inference_mode():
self.model(
input_ids,
position_ids=torch.arange(
start_pos, prompt_len, dtype=torch.long, device=self.device
)
.unsqueeze(0)
.expand(batch_sz, -1),
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
)
def execute_decode(self, tasks: List[Task]) -> List[int]: def execute_decode(self, tasks: List[Task]) -> List[int]:
if not tasks: if not tasks:
return [] return []

View File

@ -58,15 +58,6 @@ class GenerateResult:
return self.results.copy() return self.results.copy()
def _validate_params(top_k: int, top_p: float, temperature: float) -> None:
    """Reject sampling hyperparameters that fall outside their valid ranges.

    Args:
        top_k: Must be an ``int`` greater than or equal to 0.
        top_p: Must lie in the closed interval [0.0, 1.0].
        temperature: Must be an ``int`` or ``float`` greater than or equal
            to 0.

    Raises:
        ValueError: If any argument fails its check, including a ``top_p``
            that cannot be ordered against a number at all.
    """
    if not (isinstance(top_k, int) and top_k >= 0):
        raise ValueError("top_k must be a non-negative integer")

    try:
        top_p_in_range = 0.0 <= top_p <= 1.0
    except TypeError:
        # None, str, etc. cannot be compared with floats; report them with the
        # same ValueError as an out-of-range number instead of a TypeError.
        top_p_in_range = False
    if not top_p_in_range:
        raise ValueError("top_p must be a float between 0.0 and 1.0")

    if not (isinstance(temperature, (int, float)) and temperature >= 0):
        raise ValueError("temperature must be a non-negative number")
@dataclass(frozen=True) @dataclass(frozen=True)
class GenerationParams: class GenerationParams:
"""Immutable value object for sampling hyperparameters.""" """Immutable value object for sampling hyperparameters."""
@ -76,6 +67,14 @@ class GenerationParams:
temperature: float = 1.0 temperature: float = 1.0
max_tokens: int = 1024 max_tokens: int = 1024
def __post_init__(self):
    """Validate the sampling fields once the dataclass has been initialised.

    Only ``top_k``, ``top_p`` and ``temperature`` are checked here.

    Raises:
        ValueError: If ``top_k`` is not a non-negative ``int``, if ``top_p``
            is outside [0.0, 1.0] or cannot be ordered against a number, or
            if ``temperature`` is not a non-negative ``int``/``float``.
    """
    if not (isinstance(self.top_k, int) and self.top_k >= 0):
        raise ValueError("top_k must be a non-negative integer")

    try:
        top_p_in_range = 0.0 <= self.top_p <= 1.0
    except TypeError:
        # None, str, etc. cannot be compared with floats; report them with the
        # same ValueError as an out-of-range number instead of a TypeError.
        top_p_in_range = False
    if not top_p_in_range:
        raise ValueError("top_p must be a float between 0.0 and 1.0")

    if not (isinstance(self.temperature, (int, float)) and self.temperature >= 0):
        raise ValueError("temperature must be a non-negative number")
class GenerationRequest: class GenerationRequest:
"""Request parameters for text generation.""" """Request parameters for text generation."""
@ -97,7 +96,6 @@ class GenerationRequest:
max_tokens=max_len, max_tokens=max_len,
) )
self.stream = stream self.stream = stream
_validate_params(top_k, top_p, temperature)
@property @property
def top_k(self) -> int: def top_k(self) -> int:
@ -157,31 +155,32 @@ class InferenceEngine:
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = 50, top_k: int = 50,
) -> Union[Generator, str, List[str]]: ) -> Union[Generator, str, List[str]]:
_validate_params(top_k, top_p, temperature) params = GenerationParams(
top_k=top_k, top_p=top_p, temperature=temperature, max_tokens=max_tokens
)
is_batch = isinstance(prompt, list) is_batch = isinstance(prompt, list)
prompts = prompt if is_batch else [prompt] prompts = prompt if is_batch else [prompt]
if stream: if stream:
return self._generate_streaming( return self._generate_streaming(prompts, is_batch, params)
prompts, is_batch, max_tokens, temperature, top_p, top_k
)
else: else:
return self._generate_non_streaming( return self._generate_non_streaming(prompts, is_batch, params)
prompts, is_batch, max_tokens, temperature, top_p, top_k
)
def generate_async( def generate_async(
self, self,
prompt: str, prompt: str,
params: Optional[GenerationParams] = None,
max_tokens: int = 1024, max_tokens: int = 1024,
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = 50, top_k: int = 50,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
sync_gen = self._generate_streaming( if params is None:
[prompt], False, max_tokens, temperature, top_p, top_k params = GenerationParams(
) top_k=top_k, top_p=top_p, temperature=temperature, max_tokens=max_tokens
)
sync_gen = self._generate_streaming([prompt], False, params)
async def _agen(): async def _agen():
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -214,12 +213,7 @@ class InferenceEngine:
) )
def _submit_tasks( def _submit_tasks(
self, self, prompts: List[str], params: GenerationParams
prompts: List[str],
max_tokens: int,
temperature: float,
top_p: float,
top_k: int,
) -> Tuple[GenerateResult, List[str]]: ) -> Tuple[GenerateResult, List[str]]:
n = len(prompts) n = len(prompts)
result = GenerateResult(count=n) result = GenerateResult(count=n)
@ -228,10 +222,10 @@ class InferenceEngine:
cb = self._make_callback(result, i) cb = self._make_callback(result, i)
task_id = self.scheduler.add_task( task_id = self.scheduler.add_task(
prompt=p, prompt=p,
max_tokens=max_tokens, max_tokens=params.max_tokens,
temperature=temperature, temperature=params.temperature,
top_p=top_p, top_p=params.top_p,
top_k=top_k, top_k=params.top_k,
stream_callback=cb, stream_callback=cb,
) )
task_ids.append(task_id) task_ids.append(task_id)
@ -245,17 +239,9 @@ class InferenceEngine:
return cb return cb
def _generate_streaming( def _generate_streaming(
self, self, prompts: List[str], is_batch: bool, params: GenerationParams
prompts: List[str],
is_batch: bool,
max_tokens: int,
temperature: float,
top_p: float,
top_k: int,
) -> Generator: ) -> Generator:
result, task_ids = self._submit_tasks( result, task_ids = self._submit_tasks(prompts, params)
prompts, max_tokens, temperature, top_p, top_k
)
n = len(prompts) n = len(prompts)
remaining = n remaining = n
finished = [False] * n finished = [False] * n
@ -281,17 +267,9 @@ class InferenceEngine:
return gen() return gen()
def _generate_non_streaming( def _generate_non_streaming(
self, self, prompts: List[str], is_batch: bool, params: GenerationParams
prompts: List[str],
is_batch: bool,
max_tokens: int,
temperature: float,
top_p: float,
top_k: int,
) -> Union[str, List[str]]: ) -> Union[str, List[str]]:
result, task_ids = self._submit_tasks( result, task_ids = self._submit_tasks(prompts, params)
prompts, max_tokens, temperature, top_p, top_k
)
result.wait_completion() result.wait_completion()

View File

@ -4,13 +4,13 @@ AutoModel base class for model loading and saving.
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Self, Type, Union from typing import Self, Union
import safetensors.torch as st import safetensors.torch as st
import torch.nn as nn import torch.nn as nn
from astrai.config import ModelConfig from astrai.config import ModelConfig
from astrai.factory import Registry from astrai.factory import BaseFactory
@contextmanager @contextmanager
@ -39,46 +39,16 @@ def _disable_random_init(enable: bool = True):
setattr(nn.init, name, orig_func) setattr(nn.init, name, orig_func)
class AutoModel(nn.Module): class AutoModel(BaseFactory["AutoModel"], nn.Module):
""" """
Autoregressive language model base class. Autoregressive language model base class.
Provides model loading/saving and generation capabilities. Provides model loading/saving, registration, and generation.
""" """
_registry = Registry()
def __init__(self, config: ModelConfig): def __init__(self, config: ModelConfig):
super().__init__() super().__init__()
self.config = config self.config = config
@classmethod
def register(cls, model_type: str):
    """Return a class decorator that records a model class under ``model_type``.

    The name is lower-cased before it is stored in the registry, and the
    decorated class is handed back unchanged.

    Example:
        @AutoModel.register('transformer')
        class Transformer(AutoModel):
            ...
    """

    def _add_to_registry(model_cls: Type["AutoModel"]) -> Type["AutoModel"]:
        key = model_type.lower()
        cls._registry.register(key, model_cls)
        return model_cls

    return _add_to_registry
@classmethod
def get_model_class(cls, model_type: str) -> Type["AutoModel"]:
    """Resolve a model-type string to the class registered for it.

    ``model_type`` is lower-cased before the registry is consulted.

    Raises:
        ValueError: If nothing is registered under the lower-cased name.
    """
    key = model_type.lower()
    registry = cls._registry
    if registry.contains(key):
        return registry.get(key)
    raise ValueError(
        f"Unknown model_type: {key}. Available: {registry.list_names()}"
    )
@classmethod @classmethod
def from_pretrained( def from_pretrained(
cls, cls,
@ -98,7 +68,7 @@ class AutoModel(nn.Module):
raise FileNotFoundError(f"Config file not found: {config_path}") raise FileNotFoundError(f"Config file not found: {config_path}")
model_type = config.model_type or "transformer" model_type = config.model_type or "transformer"
actual_cls = cls.get_model_class(model_type) actual_cls = AutoModel.get_component_class(model_type)
with _disable_random_init(enable=disable_random_init): with _disable_random_init(enable=disable_random_init):
model = actual_cls(config) model = actual_cls(config)

View File

@ -69,12 +69,6 @@ class CallbackFactory(BaseFactory[TrainCallback]):
callback = CallbackFactory.create("my_callback", **kwargs) callback = CallbackFactory.create("my_callback", **kwargs)
""" """
@classmethod
def _validate_component(cls, callback_cls: type) -> None:
    """Ensure ``callback_cls`` is a subclass of ``TrainCallback``.

    Raises:
        TypeError: If ``callback_cls`` does not derive from ``TrainCallback``.
    """
    if issubclass(callback_cls, TrainCallback):
        return
    raise TypeError(f"{callback_cls.__name__} must inherit from TrainCallback")
@CallbackFactory.register("gradient_clipping") @CallbackFactory.register("gradient_clipping")
class GradientClippingCallback(TrainCallback): class GradientClippingCallback(TrainCallback):