refactor: 消除多处重复模式,统一工厂和参数传递

- AutoModel 继承 BaseFactory,消除自建 Registry(-30 行)
- executor.execute_prefill 删除重复 forward 代码块(bug)
- train_callback 移除 Protocol 上矛盾的 issubclass 检查
- engine.py 内部方法统一传 GenerationParams,校验内聚
- protocol.py SSEBuilder 类→函数,handle() 用 GenerationParams
- StreamContext 动态属性改为显式 dataclass 字段
- BaseFactory 新增 get_component_class 方法
This commit is contained in:
ViperEkura 2026-05-14 18:00:50 +08:00
parent 2196c34c52
commit 18fe6e9339
8 changed files with 84 additions and 147 deletions

View File

@ -155,6 +155,26 @@ class BaseFactory(ABC, Generic[T]):
"""
pass
@classmethod
def get_component_class(cls, name: str) -> Type[T]:
"""Get the registered component class by name without instantiating it.
Args:
name: Registered name of the component
Returns:
The component class itself
Raises:
ValueError: If the component name is not registered
"""
if not cls._registry.contains(name):
raise ValueError(
f"Unknown component: '{name}'. "
f"Supported types: {sorted(cls._registry.list_names())}"
)
return cls._registry.get(name)
@classmethod
def list_registered(cls) -> list:
"""List all registered component names.

View File

@ -15,7 +15,6 @@ from astrai.inference.api import (
MessagesRequest,
OpenAIHandler,
ProtocolHandler,
SSEBuilder,
StopChecker,
StreamContext,
app,
@ -77,7 +76,6 @@ __all__ = [
"SamplingPipeline",
# Protocol
"ProtocolHandler",
"SSEBuilder",
"StopChecker",
"StreamContext",
"AnthropicHandler",

View File

@ -4,7 +4,6 @@ from astrai.inference.api.protocol import (
AnthropicHandler,
OpenAIHandler,
ProtocolHandler,
SSEBuilder,
StopChecker,
StreamContext,
)
@ -21,7 +20,6 @@ __all__ = [
"AnthropicHandler",
"OpenAIHandler",
"ProtocolHandler",
"SSEBuilder",
"StopChecker",
"StreamContext",
"AnthropicMessage",

View File

@ -14,14 +14,10 @@ from typing import Any, Dict, List, Optional, Union
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from astrai.inference.engine import InferenceEngine
from astrai.inference.engine import GenerationParams, InferenceEngine
class SSEBuilder:
"""Fluent builder for SSE (Server-Sent Events) formatted chunks."""
@staticmethod
def event(data: Dict[str, Any], event: Optional[str] = None) -> str:
def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
lines: List[str] = []
if event:
lines.append(f"event: {event}")
@ -29,8 +25,8 @@ class SSEBuilder:
lines.append("")
return "\n".join(lines)
@staticmethod
def done() -> str:
def _sse_done() -> str:
return "data: [DONE]\n\n"
@ -44,6 +40,8 @@ class StreamContext:
prompt_tokens: int
completion_tokens: int = 0
accumulated: str = ""
stop_matched: Optional[str] = None
last_yield_trimmed: str = ""
class StopChecker:
@ -145,13 +143,13 @@ class ProtocolHandler(ABC):
prompt_tokens=self._count_prompt_tokens(),
)
agen = self.engine.generate_async(
prompt=self.build_prompt(),
params = GenerationParams(
max_tokens=self.request.max_tokens,
temperature=self.request.temperature,
top_p=self.request.top_p,
top_k=self.request.top_k,
)
agen = self.engine.generate_async(prompt=self.build_prompt(), params=params)
if self.request.stream:
return self._handle_stream(agen, ctx)
@ -180,7 +178,7 @@ class ProtocolHandler(ABC):
for event in self.format_stream_end(ctx):
yield event
yield SSEBuilder.done()
yield _sse_done()
return StreamingResponse(
event_stream(),
@ -230,7 +228,7 @@ class OpenAIHandler(ProtocolHandler):
def format_stream_start(self, ctx: StreamContext) -> List[str]:
return [
SSEBuilder.event(
_sse_event(
{
"id": ctx.resp_id,
"object": "chat.completion.chunk",
@ -248,7 +246,7 @@ class OpenAIHandler(ProtocolHandler):
]
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
return SSEBuilder.event(
return _sse_event(
{
"id": ctx.resp_id,
"object": "chat.completion.chunk",
@ -262,7 +260,7 @@ class OpenAIHandler(ProtocolHandler):
def format_stream_end(self, ctx: StreamContext) -> List[str]:
return [
SSEBuilder.event(
_sse_event(
{
"id": ctx.resp_id,
"object": "chat.completion.chunk",
@ -271,7 +269,7 @@ class OpenAIHandler(ProtocolHandler):
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
}
),
SSEBuilder.event(
_sse_event(
{
"prompt_tokens": ctx.prompt_tokens,
"completion_tokens": ctx.completion_tokens,
@ -334,16 +332,16 @@ class AnthropicHandler(ProtocolHandler):
if not matched:
return None
ctx._stop_matched = matched
ctx.stop_matched = matched
trimmed = ctx.accumulated[: ctx.accumulated.rfind(matched)]
unyielded = trimmed[len(self._yielded) :]
if unyielded:
ctx._last_yield_trimmed = unyielded
ctx.last_yield_trimmed = unyielded
return matched
def format_stream_start(self, ctx: StreamContext) -> List[str]:
return [
SSEBuilder.event(
_sse_event(
{
"type": "message_start",
"message": {
@ -357,7 +355,7 @@ class AnthropicHandler(ProtocolHandler):
},
event="message_start",
),
SSEBuilder.event(
_sse_event(
{
"type": "content_block_start",
"index": 0,
@ -369,7 +367,7 @@ class AnthropicHandler(ProtocolHandler):
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
self._yielded += token
return SSEBuilder.event(
return _sse_event(
{
"type": "content_block_delta",
"index": 0,
@ -379,12 +377,12 @@ class AnthropicHandler(ProtocolHandler):
)
def format_stream_end(self, ctx: StreamContext) -> List[str]:
matched = getattr(ctx, "_stop_matched", None)
matched = ctx.stop_matched
events: List[str] = []
last_yielded = getattr(ctx, "_last_yield_trimmed", "")
last_yielded = ctx.last_yield_trimmed
if last_yielded:
events.append(
SSEBuilder.event(
_sse_event(
{
"type": "content_block_delta",
"index": 0,
@ -394,13 +392,13 @@ class AnthropicHandler(ProtocolHandler):
)
)
events.append(
SSEBuilder.event(
_sse_event(
{"type": "content_block_stop", "index": 0},
event="content_block_stop",
)
)
events.append(
SSEBuilder.event(
_sse_event(
{
"type": "message_delta",
"delta": {
@ -412,13 +410,13 @@ class AnthropicHandler(ProtocolHandler):
event="message_delta",
)
)
events.append(SSEBuilder.event({"type": "message_stop"}, event="message_stop"))
events.append(_sse_event({"type": "message_stop"}, event="message_stop"))
return events
def format_non_stream_response(
self, ctx: StreamContext, content: str
) -> Dict[str, Any]:
matched = getattr(ctx, "_stop_matched", None)
matched = ctx.stop_matched
if matched:
content = content[: content.rfind(matched)]
return {

View File

@ -60,25 +60,6 @@ class Executor:
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
)
for i, t in enumerate(tasks):
input_ids[i] = torch.tensor(
t.prompt_ids[start_pos:prompt_len], device=self.device
)
task_ids = [t.task_id for t in tasks]
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
with torch.inference_mode():
self.model(
input_ids,
position_ids=torch.arange(
start_pos, prompt_len, dtype=torch.long, device=self.device
)
.unsqueeze(0)
.expand(batch_sz, -1),
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
)
def execute_decode(self, tasks: List[Task]) -> List[int]:
if not tasks:
return []

View File

@ -58,15 +58,6 @@ class GenerateResult:
return self.results.copy()
def _validate_params(top_k: int, top_p: float, temperature: float) -> None:
if not (isinstance(top_k, int) and top_k >= 0):
raise ValueError("top_k must be a non-negative integer")
if not (0.0 <= top_p <= 1.0):
raise ValueError("top_p must be a float between 0.0 and 1.0")
if not (isinstance(temperature, (int, float)) and temperature >= 0):
raise ValueError("temperature must be a non-negative number")
@dataclass(frozen=True)
class GenerationParams:
"""Immutable value object for sampling hyperparameters."""
@ -76,6 +67,14 @@ class GenerationParams:
temperature: float = 1.0
max_tokens: int = 1024
def __post_init__(self):
if not (isinstance(self.top_k, int) and self.top_k >= 0):
raise ValueError("top_k must be a non-negative integer")
if not (0.0 <= self.top_p <= 1.0):
raise ValueError("top_p must be a float between 0.0 and 1.0")
if not (isinstance(self.temperature, (int, float)) and self.temperature >= 0):
raise ValueError("temperature must be a non-negative number")
class GenerationRequest:
"""Request parameters for text generation."""
@ -97,7 +96,6 @@ class GenerationRequest:
max_tokens=max_len,
)
self.stream = stream
_validate_params(top_k, top_p, temperature)
@property
def top_k(self) -> int:
@ -157,31 +155,32 @@ class InferenceEngine:
top_p: float = 1.0,
top_k: int = 50,
) -> Union[Generator, str, List[str]]:
_validate_params(top_k, top_p, temperature)
params = GenerationParams(
top_k=top_k, top_p=top_p, temperature=temperature, max_tokens=max_tokens
)
is_batch = isinstance(prompt, list)
prompts = prompt if is_batch else [prompt]
if stream:
return self._generate_streaming(
prompts, is_batch, max_tokens, temperature, top_p, top_k
)
return self._generate_streaming(prompts, is_batch, params)
else:
return self._generate_non_streaming(
prompts, is_batch, max_tokens, temperature, top_p, top_k
)
return self._generate_non_streaming(prompts, is_batch, params)
def generate_async(
self,
prompt: str,
params: Optional[GenerationParams] = None,
max_tokens: int = 1024,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
) -> AsyncGenerator[str, None]:
sync_gen = self._generate_streaming(
[prompt], False, max_tokens, temperature, top_p, top_k
if params is None:
params = GenerationParams(
top_k=top_k, top_p=top_p, temperature=temperature, max_tokens=max_tokens
)
sync_gen = self._generate_streaming([prompt], False, params)
async def _agen():
loop = asyncio.get_event_loop()
@ -214,12 +213,7 @@ class InferenceEngine:
)
def _submit_tasks(
self,
prompts: List[str],
max_tokens: int,
temperature: float,
top_p: float,
top_k: int,
self, prompts: List[str], params: GenerationParams
) -> Tuple[GenerateResult, List[str]]:
n = len(prompts)
result = GenerateResult(count=n)
@ -228,10 +222,10 @@ class InferenceEngine:
cb = self._make_callback(result, i)
task_id = self.scheduler.add_task(
prompt=p,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
max_tokens=params.max_tokens,
temperature=params.temperature,
top_p=params.top_p,
top_k=params.top_k,
stream_callback=cb,
)
task_ids.append(task_id)
@ -245,17 +239,9 @@ class InferenceEngine:
return cb
def _generate_streaming(
self,
prompts: List[str],
is_batch: bool,
max_tokens: int,
temperature: float,
top_p: float,
top_k: int,
self, prompts: List[str], is_batch: bool, params: GenerationParams
) -> Generator:
result, task_ids = self._submit_tasks(
prompts, max_tokens, temperature, top_p, top_k
)
result, task_ids = self._submit_tasks(prompts, params)
n = len(prompts)
remaining = n
finished = [False] * n
@ -281,17 +267,9 @@ class InferenceEngine:
return gen()
def _generate_non_streaming(
self,
prompts: List[str],
is_batch: bool,
max_tokens: int,
temperature: float,
top_p: float,
top_k: int,
self, prompts: List[str], is_batch: bool, params: GenerationParams
) -> Union[str, List[str]]:
result, task_ids = self._submit_tasks(
prompts, max_tokens, temperature, top_p, top_k
)
result, task_ids = self._submit_tasks(prompts, params)
result.wait_completion()

View File

@ -4,13 +4,13 @@ AutoModel base class for model loading and saving.
from contextlib import contextmanager
from pathlib import Path
from typing import Self, Type, Union
from typing import Self, Union
import safetensors.torch as st
import torch.nn as nn
from astrai.config import ModelConfig
from astrai.factory import Registry
from astrai.factory import BaseFactory
@contextmanager
@ -39,46 +39,16 @@ def _disable_random_init(enable: bool = True):
setattr(nn.init, name, orig_func)
class AutoModel(nn.Module):
class AutoModel(BaseFactory["AutoModel"], nn.Module):
"""
Autoregressive language model base class.
Provides model loading/saving and generation capabilities.
Provides model loading/saving, registration, and generation.
"""
_registry = Registry()
def __init__(self, config: ModelConfig):
super().__init__()
self.config = config
@classmethod
def register(cls, model_type: str):
"""
Class method decorator to register model type.
Usage:
@AutoModel.register('transformer')
class Transformer(AutoModel):
...
"""
def decorator(sub_cls: Type["AutoModel"]) -> Type["AutoModel"]:
cls._registry.register(model_type.lower(), sub_cls)
return sub_cls
return decorator
@classmethod
def get_model_class(cls, model_type: str) -> Type["AutoModel"]:
"""Get model class by model_type string."""
model_type = model_type.lower()
if not cls._registry.contains(model_type):
available = cls._registry.list_names()
raise ValueError(
f"Unknown model_type: {model_type}. Available: {available}"
)
return cls._registry.get(model_type)
@classmethod
def from_pretrained(
cls,
@ -98,7 +68,7 @@ class AutoModel(nn.Module):
raise FileNotFoundError(f"Config file not found: {config_path}")
model_type = config.model_type or "transformer"
actual_cls = cls.get_model_class(model_type)
actual_cls = AutoModel.get_component_class(model_type)
with _disable_random_init(enable=disable_random_init):
model = actual_cls(config)

View File

@ -69,12 +69,6 @@ class CallbackFactory(BaseFactory[TrainCallback]):
callback = CallbackFactory.create("my_callback", **kwargs)
"""
@classmethod
def _validate_component(cls, callback_cls: type) -> None:
"""Validate that the callback class inherits from TrainCallback."""
if not issubclass(callback_cls, TrainCallback):
raise TypeError(f"{callback_cls.__name__} must inherit from TrainCallback")
@CallbackFactory.register("gradient_clipping")
class GradientClippingCallback(TrainCallback):