refactor: 消除多处重复模式,统一工厂和参数传递
- AutoModel 继承 BaseFactory,消除自建 Registry(-30 行) - executor.execute_prefill 删除重复 forward 代码块(bug) - train_callback 移除 Protocol 上矛盾的 issubclass 检查 - engine.py 内部方法统一传 GenerationParams,校验内聚 - protocol.py SSEBuilder 类→函数,handle() 用 GenerationParams - StreamContext 动态属性改为显式 dataclass 字段 - BaseFactory 新增 get_component_class 方法
This commit is contained in:
parent
2196c34c52
commit
18fe6e9339
|
|
@ -155,6 +155,26 @@ class BaseFactory(ABC, Generic[T]):
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_component_class(cls, name: str) -> Type[T]:
|
||||||
|
"""Get the registered component class by name without instantiating it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Registered name of the component
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The component class itself
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the component name is not registered
|
||||||
|
"""
|
||||||
|
if not cls._registry.contains(name):
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown component: '{name}'. "
|
||||||
|
f"Supported types: {sorted(cls._registry.list_names())}"
|
||||||
|
)
|
||||||
|
return cls._registry.get(name)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def list_registered(cls) -> list:
|
def list_registered(cls) -> list:
|
||||||
"""List all registered component names.
|
"""List all registered component names.
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,6 @@ from astrai.inference.api import (
|
||||||
MessagesRequest,
|
MessagesRequest,
|
||||||
OpenAIHandler,
|
OpenAIHandler,
|
||||||
ProtocolHandler,
|
ProtocolHandler,
|
||||||
SSEBuilder,
|
|
||||||
StopChecker,
|
StopChecker,
|
||||||
StreamContext,
|
StreamContext,
|
||||||
app,
|
app,
|
||||||
|
|
@ -77,7 +76,6 @@ __all__ = [
|
||||||
"SamplingPipeline",
|
"SamplingPipeline",
|
||||||
# Protocol
|
# Protocol
|
||||||
"ProtocolHandler",
|
"ProtocolHandler",
|
||||||
"SSEBuilder",
|
|
||||||
"StopChecker",
|
"StopChecker",
|
||||||
"StreamContext",
|
"StreamContext",
|
||||||
"AnthropicHandler",
|
"AnthropicHandler",
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ from astrai.inference.api.protocol import (
|
||||||
AnthropicHandler,
|
AnthropicHandler,
|
||||||
OpenAIHandler,
|
OpenAIHandler,
|
||||||
ProtocolHandler,
|
ProtocolHandler,
|
||||||
SSEBuilder,
|
|
||||||
StopChecker,
|
StopChecker,
|
||||||
StreamContext,
|
StreamContext,
|
||||||
)
|
)
|
||||||
|
|
@ -21,7 +20,6 @@ __all__ = [
|
||||||
"AnthropicHandler",
|
"AnthropicHandler",
|
||||||
"OpenAIHandler",
|
"OpenAIHandler",
|
||||||
"ProtocolHandler",
|
"ProtocolHandler",
|
||||||
"SSEBuilder",
|
|
||||||
"StopChecker",
|
"StopChecker",
|
||||||
"StreamContext",
|
"StreamContext",
|
||||||
"AnthropicMessage",
|
"AnthropicMessage",
|
||||||
|
|
|
||||||
|
|
@ -14,24 +14,20 @@ from typing import Any, Dict, List, Optional, Union
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from astrai.inference.engine import InferenceEngine
|
from astrai.inference.engine import GenerationParams, InferenceEngine
|
||||||
|
|
||||||
|
|
||||||
class SSEBuilder:
|
def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
||||||
"""Fluent builder for SSE (Server-Sent Events) formatted chunks."""
|
lines: List[str] = []
|
||||||
|
if event:
|
||||||
|
lines.append(f"event: {event}")
|
||||||
|
lines.append(f"data: {json.dumps(data, ensure_ascii=False)}")
|
||||||
|
lines.append("")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
|
||||||
lines: List[str] = []
|
|
||||||
if event:
|
|
||||||
lines.append(f"event: {event}")
|
|
||||||
lines.append(f"data: {json.dumps(data, ensure_ascii=False)}")
|
|
||||||
lines.append("")
|
|
||||||
return "\n".join(lines)
|
|
||||||
|
|
||||||
@staticmethod
|
def _sse_done() -> str:
|
||||||
def done() -> str:
|
return "data: [DONE]\n\n"
|
||||||
return "data: [DONE]\n\n"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -44,6 +40,8 @@ class StreamContext:
|
||||||
prompt_tokens: int
|
prompt_tokens: int
|
||||||
completion_tokens: int = 0
|
completion_tokens: int = 0
|
||||||
accumulated: str = ""
|
accumulated: str = ""
|
||||||
|
stop_matched: Optional[str] = None
|
||||||
|
last_yield_trimmed: str = ""
|
||||||
|
|
||||||
|
|
||||||
class StopChecker:
|
class StopChecker:
|
||||||
|
|
@ -145,13 +143,13 @@ class ProtocolHandler(ABC):
|
||||||
prompt_tokens=self._count_prompt_tokens(),
|
prompt_tokens=self._count_prompt_tokens(),
|
||||||
)
|
)
|
||||||
|
|
||||||
agen = self.engine.generate_async(
|
params = GenerationParams(
|
||||||
prompt=self.build_prompt(),
|
|
||||||
max_tokens=self.request.max_tokens,
|
max_tokens=self.request.max_tokens,
|
||||||
temperature=self.request.temperature,
|
temperature=self.request.temperature,
|
||||||
top_p=self.request.top_p,
|
top_p=self.request.top_p,
|
||||||
top_k=self.request.top_k,
|
top_k=self.request.top_k,
|
||||||
)
|
)
|
||||||
|
agen = self.engine.generate_async(prompt=self.build_prompt(), params=params)
|
||||||
|
|
||||||
if self.request.stream:
|
if self.request.stream:
|
||||||
return self._handle_stream(agen, ctx)
|
return self._handle_stream(agen, ctx)
|
||||||
|
|
@ -180,7 +178,7 @@ class ProtocolHandler(ABC):
|
||||||
|
|
||||||
for event in self.format_stream_end(ctx):
|
for event in self.format_stream_end(ctx):
|
||||||
yield event
|
yield event
|
||||||
yield SSEBuilder.done()
|
yield _sse_done()
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
event_stream(),
|
event_stream(),
|
||||||
|
|
@ -230,7 +228,7 @@ class OpenAIHandler(ProtocolHandler):
|
||||||
|
|
||||||
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
||||||
return [
|
return [
|
||||||
SSEBuilder.event(
|
_sse_event(
|
||||||
{
|
{
|
||||||
"id": ctx.resp_id,
|
"id": ctx.resp_id,
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
|
|
@ -248,7 +246,7 @@ class OpenAIHandler(ProtocolHandler):
|
||||||
]
|
]
|
||||||
|
|
||||||
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
|
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
|
||||||
return SSEBuilder.event(
|
return _sse_event(
|
||||||
{
|
{
|
||||||
"id": ctx.resp_id,
|
"id": ctx.resp_id,
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
|
|
@ -262,7 +260,7 @@ class OpenAIHandler(ProtocolHandler):
|
||||||
|
|
||||||
def format_stream_end(self, ctx: StreamContext) -> List[str]:
|
def format_stream_end(self, ctx: StreamContext) -> List[str]:
|
||||||
return [
|
return [
|
||||||
SSEBuilder.event(
|
_sse_event(
|
||||||
{
|
{
|
||||||
"id": ctx.resp_id,
|
"id": ctx.resp_id,
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
|
|
@ -271,7 +269,7 @@ class OpenAIHandler(ProtocolHandler):
|
||||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
SSEBuilder.event(
|
_sse_event(
|
||||||
{
|
{
|
||||||
"prompt_tokens": ctx.prompt_tokens,
|
"prompt_tokens": ctx.prompt_tokens,
|
||||||
"completion_tokens": ctx.completion_tokens,
|
"completion_tokens": ctx.completion_tokens,
|
||||||
|
|
@ -334,16 +332,16 @@ class AnthropicHandler(ProtocolHandler):
|
||||||
if not matched:
|
if not matched:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
ctx._stop_matched = matched
|
ctx.stop_matched = matched
|
||||||
trimmed = ctx.accumulated[: ctx.accumulated.rfind(matched)]
|
trimmed = ctx.accumulated[: ctx.accumulated.rfind(matched)]
|
||||||
unyielded = trimmed[len(self._yielded) :]
|
unyielded = trimmed[len(self._yielded) :]
|
||||||
if unyielded:
|
if unyielded:
|
||||||
ctx._last_yield_trimmed = unyielded
|
ctx.last_yield_trimmed = unyielded
|
||||||
return matched
|
return matched
|
||||||
|
|
||||||
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
||||||
return [
|
return [
|
||||||
SSEBuilder.event(
|
_sse_event(
|
||||||
{
|
{
|
||||||
"type": "message_start",
|
"type": "message_start",
|
||||||
"message": {
|
"message": {
|
||||||
|
|
@ -357,7 +355,7 @@ class AnthropicHandler(ProtocolHandler):
|
||||||
},
|
},
|
||||||
event="message_start",
|
event="message_start",
|
||||||
),
|
),
|
||||||
SSEBuilder.event(
|
_sse_event(
|
||||||
{
|
{
|
||||||
"type": "content_block_start",
|
"type": "content_block_start",
|
||||||
"index": 0,
|
"index": 0,
|
||||||
|
|
@ -369,7 +367,7 @@ class AnthropicHandler(ProtocolHandler):
|
||||||
|
|
||||||
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
|
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
|
||||||
self._yielded += token
|
self._yielded += token
|
||||||
return SSEBuilder.event(
|
return _sse_event(
|
||||||
{
|
{
|
||||||
"type": "content_block_delta",
|
"type": "content_block_delta",
|
||||||
"index": 0,
|
"index": 0,
|
||||||
|
|
@ -379,12 +377,12 @@ class AnthropicHandler(ProtocolHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
def format_stream_end(self, ctx: StreamContext) -> List[str]:
|
def format_stream_end(self, ctx: StreamContext) -> List[str]:
|
||||||
matched = getattr(ctx, "_stop_matched", None)
|
matched = ctx.stop_matched
|
||||||
events: List[str] = []
|
events: List[str] = []
|
||||||
last_yielded = getattr(ctx, "_last_yield_trimmed", "")
|
last_yielded = ctx.last_yield_trimmed
|
||||||
if last_yielded:
|
if last_yielded:
|
||||||
events.append(
|
events.append(
|
||||||
SSEBuilder.event(
|
_sse_event(
|
||||||
{
|
{
|
||||||
"type": "content_block_delta",
|
"type": "content_block_delta",
|
||||||
"index": 0,
|
"index": 0,
|
||||||
|
|
@ -394,13 +392,13 @@ class AnthropicHandler(ProtocolHandler):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
events.append(
|
events.append(
|
||||||
SSEBuilder.event(
|
_sse_event(
|
||||||
{"type": "content_block_stop", "index": 0},
|
{"type": "content_block_stop", "index": 0},
|
||||||
event="content_block_stop",
|
event="content_block_stop",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
events.append(
|
events.append(
|
||||||
SSEBuilder.event(
|
_sse_event(
|
||||||
{
|
{
|
||||||
"type": "message_delta",
|
"type": "message_delta",
|
||||||
"delta": {
|
"delta": {
|
||||||
|
|
@ -412,13 +410,13 @@ class AnthropicHandler(ProtocolHandler):
|
||||||
event="message_delta",
|
event="message_delta",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
events.append(SSEBuilder.event({"type": "message_stop"}, event="message_stop"))
|
events.append(_sse_event({"type": "message_stop"}, event="message_stop"))
|
||||||
return events
|
return events
|
||||||
|
|
||||||
def format_non_stream_response(
|
def format_non_stream_response(
|
||||||
self, ctx: StreamContext, content: str
|
self, ctx: StreamContext, content: str
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
matched = getattr(ctx, "_stop_matched", None)
|
matched = ctx.stop_matched
|
||||||
if matched:
|
if matched:
|
||||||
content = content[: content.rfind(matched)]
|
content = content[: content.rfind(matched)]
|
||||||
return {
|
return {
|
||||||
|
|
|
||||||
|
|
@ -60,25 +60,6 @@ class Executor:
|
||||||
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
|
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
|
||||||
)
|
)
|
||||||
|
|
||||||
for i, t in enumerate(tasks):
|
|
||||||
input_ids[i] = torch.tensor(
|
|
||||||
t.prompt_ids[start_pos:prompt_len], device=self.device
|
|
||||||
)
|
|
||||||
|
|
||||||
task_ids = [t.task_id for t in tasks]
|
|
||||||
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
|
|
||||||
|
|
||||||
with torch.inference_mode():
|
|
||||||
self.model(
|
|
||||||
input_ids,
|
|
||||||
position_ids=torch.arange(
|
|
||||||
start_pos, prompt_len, dtype=torch.long, device=self.device
|
|
||||||
)
|
|
||||||
.unsqueeze(0)
|
|
||||||
.expand(batch_sz, -1),
|
|
||||||
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
|
|
||||||
)
|
|
||||||
|
|
||||||
def execute_decode(self, tasks: List[Task]) -> List[int]:
|
def execute_decode(self, tasks: List[Task]) -> List[int]:
|
||||||
if not tasks:
|
if not tasks:
|
||||||
return []
|
return []
|
||||||
|
|
|
||||||
|
|
@ -58,15 +58,6 @@ class GenerateResult:
|
||||||
return self.results.copy()
|
return self.results.copy()
|
||||||
|
|
||||||
|
|
||||||
def _validate_params(top_k: int, top_p: float, temperature: float) -> None:
|
|
||||||
if not (isinstance(top_k, int) and top_k >= 0):
|
|
||||||
raise ValueError("top_k must be a non-negative integer")
|
|
||||||
if not (0.0 <= top_p <= 1.0):
|
|
||||||
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
|
||||||
if not (isinstance(temperature, (int, float)) and temperature >= 0):
|
|
||||||
raise ValueError("temperature must be a non-negative number")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class GenerationParams:
|
class GenerationParams:
|
||||||
"""Immutable value object for sampling hyperparameters."""
|
"""Immutable value object for sampling hyperparameters."""
|
||||||
|
|
@ -76,6 +67,14 @@ class GenerationParams:
|
||||||
temperature: float = 1.0
|
temperature: float = 1.0
|
||||||
max_tokens: int = 1024
|
max_tokens: int = 1024
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if not (isinstance(self.top_k, int) and self.top_k >= 0):
|
||||||
|
raise ValueError("top_k must be a non-negative integer")
|
||||||
|
if not (0.0 <= self.top_p <= 1.0):
|
||||||
|
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
||||||
|
if not (isinstance(self.temperature, (int, float)) and self.temperature >= 0):
|
||||||
|
raise ValueError("temperature must be a non-negative number")
|
||||||
|
|
||||||
|
|
||||||
class GenerationRequest:
|
class GenerationRequest:
|
||||||
"""Request parameters for text generation."""
|
"""Request parameters for text generation."""
|
||||||
|
|
@ -97,7 +96,6 @@ class GenerationRequest:
|
||||||
max_tokens=max_len,
|
max_tokens=max_len,
|
||||||
)
|
)
|
||||||
self.stream = stream
|
self.stream = stream
|
||||||
_validate_params(top_k, top_p, temperature)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def top_k(self) -> int:
|
def top_k(self) -> int:
|
||||||
|
|
@ -157,31 +155,32 @@ class InferenceEngine:
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = 50,
|
top_k: int = 50,
|
||||||
) -> Union[Generator, str, List[str]]:
|
) -> Union[Generator, str, List[str]]:
|
||||||
_validate_params(top_k, top_p, temperature)
|
params = GenerationParams(
|
||||||
|
top_k=top_k, top_p=top_p, temperature=temperature, max_tokens=max_tokens
|
||||||
|
)
|
||||||
|
|
||||||
is_batch = isinstance(prompt, list)
|
is_batch = isinstance(prompt, list)
|
||||||
prompts = prompt if is_batch else [prompt]
|
prompts = prompt if is_batch else [prompt]
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return self._generate_streaming(
|
return self._generate_streaming(prompts, is_batch, params)
|
||||||
prompts, is_batch, max_tokens, temperature, top_p, top_k
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return self._generate_non_streaming(
|
return self._generate_non_streaming(prompts, is_batch, params)
|
||||||
prompts, is_batch, max_tokens, temperature, top_p, top_k
|
|
||||||
)
|
|
||||||
|
|
||||||
def generate_async(
|
def generate_async(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
|
params: Optional[GenerationParams] = None,
|
||||||
max_tokens: int = 1024,
|
max_tokens: int = 1024,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = 50,
|
top_k: int = 50,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
sync_gen = self._generate_streaming(
|
if params is None:
|
||||||
[prompt], False, max_tokens, temperature, top_p, top_k
|
params = GenerationParams(
|
||||||
)
|
top_k=top_k, top_p=top_p, temperature=temperature, max_tokens=max_tokens
|
||||||
|
)
|
||||||
|
sync_gen = self._generate_streaming([prompt], False, params)
|
||||||
|
|
||||||
async def _agen():
|
async def _agen():
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
@ -214,12 +213,7 @@ class InferenceEngine:
|
||||||
)
|
)
|
||||||
|
|
||||||
def _submit_tasks(
|
def _submit_tasks(
|
||||||
self,
|
self, prompts: List[str], params: GenerationParams
|
||||||
prompts: List[str],
|
|
||||||
max_tokens: int,
|
|
||||||
temperature: float,
|
|
||||||
top_p: float,
|
|
||||||
top_k: int,
|
|
||||||
) -> Tuple[GenerateResult, List[str]]:
|
) -> Tuple[GenerateResult, List[str]]:
|
||||||
n = len(prompts)
|
n = len(prompts)
|
||||||
result = GenerateResult(count=n)
|
result = GenerateResult(count=n)
|
||||||
|
|
@ -228,10 +222,10 @@ class InferenceEngine:
|
||||||
cb = self._make_callback(result, i)
|
cb = self._make_callback(result, i)
|
||||||
task_id = self.scheduler.add_task(
|
task_id = self.scheduler.add_task(
|
||||||
prompt=p,
|
prompt=p,
|
||||||
max_tokens=max_tokens,
|
max_tokens=params.max_tokens,
|
||||||
temperature=temperature,
|
temperature=params.temperature,
|
||||||
top_p=top_p,
|
top_p=params.top_p,
|
||||||
top_k=top_k,
|
top_k=params.top_k,
|
||||||
stream_callback=cb,
|
stream_callback=cb,
|
||||||
)
|
)
|
||||||
task_ids.append(task_id)
|
task_ids.append(task_id)
|
||||||
|
|
@ -245,17 +239,9 @@ class InferenceEngine:
|
||||||
return cb
|
return cb
|
||||||
|
|
||||||
def _generate_streaming(
|
def _generate_streaming(
|
||||||
self,
|
self, prompts: List[str], is_batch: bool, params: GenerationParams
|
||||||
prompts: List[str],
|
|
||||||
is_batch: bool,
|
|
||||||
max_tokens: int,
|
|
||||||
temperature: float,
|
|
||||||
top_p: float,
|
|
||||||
top_k: int,
|
|
||||||
) -> Generator:
|
) -> Generator:
|
||||||
result, task_ids = self._submit_tasks(
|
result, task_ids = self._submit_tasks(prompts, params)
|
||||||
prompts, max_tokens, temperature, top_p, top_k
|
|
||||||
)
|
|
||||||
n = len(prompts)
|
n = len(prompts)
|
||||||
remaining = n
|
remaining = n
|
||||||
finished = [False] * n
|
finished = [False] * n
|
||||||
|
|
@ -281,17 +267,9 @@ class InferenceEngine:
|
||||||
return gen()
|
return gen()
|
||||||
|
|
||||||
def _generate_non_streaming(
|
def _generate_non_streaming(
|
||||||
self,
|
self, prompts: List[str], is_batch: bool, params: GenerationParams
|
||||||
prompts: List[str],
|
|
||||||
is_batch: bool,
|
|
||||||
max_tokens: int,
|
|
||||||
temperature: float,
|
|
||||||
top_p: float,
|
|
||||||
top_k: int,
|
|
||||||
) -> Union[str, List[str]]:
|
) -> Union[str, List[str]]:
|
||||||
result, task_ids = self._submit_tasks(
|
result, task_ids = self._submit_tasks(prompts, params)
|
||||||
prompts, max_tokens, temperature, top_p, top_k
|
|
||||||
)
|
|
||||||
|
|
||||||
result.wait_completion()
|
result.wait_completion()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,13 +4,13 @@ AutoModel base class for model loading and saving.
|
||||||
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Self, Type, Union
|
from typing import Self, Union
|
||||||
|
|
||||||
import safetensors.torch as st
|
import safetensors.torch as st
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from astrai.config import ModelConfig
|
from astrai.config import ModelConfig
|
||||||
from astrai.factory import Registry
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
|
|
@ -39,46 +39,16 @@ def _disable_random_init(enable: bool = True):
|
||||||
setattr(nn.init, name, orig_func)
|
setattr(nn.init, name, orig_func)
|
||||||
|
|
||||||
|
|
||||||
class AutoModel(nn.Module):
|
class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
||||||
"""
|
"""
|
||||||
Autoregressive language model base class.
|
Autoregressive language model base class.
|
||||||
Provides model loading/saving and generation capabilities.
|
Provides model loading/saving, registration, and generation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_registry = Registry()
|
|
||||||
|
|
||||||
def __init__(self, config: ModelConfig):
|
def __init__(self, config: ModelConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def register(cls, model_type: str):
|
|
||||||
"""
|
|
||||||
Class method decorator to register model type.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
@AutoModel.register('transformer')
|
|
||||||
class Transformer(AutoModel):
|
|
||||||
...
|
|
||||||
"""
|
|
||||||
|
|
||||||
def decorator(sub_cls: Type["AutoModel"]) -> Type["AutoModel"]:
|
|
||||||
cls._registry.register(model_type.lower(), sub_cls)
|
|
||||||
return sub_cls
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_model_class(cls, model_type: str) -> Type["AutoModel"]:
|
|
||||||
"""Get model class by model_type string."""
|
|
||||||
model_type = model_type.lower()
|
|
||||||
if not cls._registry.contains(model_type):
|
|
||||||
available = cls._registry.list_names()
|
|
||||||
raise ValueError(
|
|
||||||
f"Unknown model_type: {model_type}. Available: {available}"
|
|
||||||
)
|
|
||||||
return cls._registry.get(model_type)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
cls,
|
cls,
|
||||||
|
|
@ -98,7 +68,7 @@ class AutoModel(nn.Module):
|
||||||
raise FileNotFoundError(f"Config file not found: {config_path}")
|
raise FileNotFoundError(f"Config file not found: {config_path}")
|
||||||
|
|
||||||
model_type = config.model_type or "transformer"
|
model_type = config.model_type or "transformer"
|
||||||
actual_cls = cls.get_model_class(model_type)
|
actual_cls = AutoModel.get_component_class(model_type)
|
||||||
|
|
||||||
with _disable_random_init(enable=disable_random_init):
|
with _disable_random_init(enable=disable_random_init):
|
||||||
model = actual_cls(config)
|
model = actual_cls(config)
|
||||||
|
|
|
||||||
|
|
@ -69,12 +69,6 @@ class CallbackFactory(BaseFactory[TrainCallback]):
|
||||||
callback = CallbackFactory.create("my_callback", **kwargs)
|
callback = CallbackFactory.create("my_callback", **kwargs)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _validate_component(cls, callback_cls: type) -> None:
|
|
||||||
"""Validate that the callback class inherits from TrainCallback."""
|
|
||||||
if not issubclass(callback_cls, TrainCallback):
|
|
||||||
raise TypeError(f"{callback_cls.__name__} must inherit from TrainCallback")
|
|
||||||
|
|
||||||
|
|
||||||
@CallbackFactory.register("gradient_clipping")
|
@CallbackFactory.register("gradient_clipping")
|
||||||
class GradientClippingCallback(TrainCallback):
|
class GradientClippingCallback(TrainCallback):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue