diff --git a/astrai/factory.py b/astrai/factory.py index 2fd4819..2a2d7a8 100644 --- a/astrai/factory.py +++ b/astrai/factory.py @@ -155,6 +155,26 @@ class BaseFactory(ABC, Generic[T]): """ pass + @classmethod + def get_component_class(cls, name: str) -> Type[T]: + """Get the registered component class by name without instantiating it. + + Args: + name: Registered name of the component + + Returns: + The component class itself + + Raises: + ValueError: If the component name is not registered + """ + if not cls._registry.contains(name): + raise ValueError( + f"Unknown component: '{name}'. " + f"Supported types: {sorted(cls._registry.list_names())}" + ) + return cls._registry.get(name) + @classmethod def list_registered(cls) -> list: """List all registered component names. diff --git a/astrai/inference/__init__.py b/astrai/inference/__init__.py index 06e8b31..6a2deb3 100644 --- a/astrai/inference/__init__.py +++ b/astrai/inference/__init__.py @@ -15,7 +15,6 @@ from astrai.inference.api import ( MessagesRequest, OpenAIHandler, ProtocolHandler, - SSEBuilder, StopChecker, StreamContext, app, @@ -77,7 +76,6 @@ __all__ = [ "SamplingPipeline", # Protocol "ProtocolHandler", - "SSEBuilder", "StopChecker", "StreamContext", "AnthropicHandler", diff --git a/astrai/inference/api/__init__.py b/astrai/inference/api/__init__.py index 84c4e10..cb1128e 100644 --- a/astrai/inference/api/__init__.py +++ b/astrai/inference/api/__init__.py @@ -4,7 +4,6 @@ from astrai.inference.api.protocol import ( AnthropicHandler, OpenAIHandler, ProtocolHandler, - SSEBuilder, StopChecker, StreamContext, ) @@ -21,7 +20,6 @@ __all__ = [ "AnthropicHandler", "OpenAIHandler", "ProtocolHandler", - "SSEBuilder", "StopChecker", "StreamContext", "AnthropicMessage", diff --git a/astrai/inference/api/protocol.py b/astrai/inference/api/protocol.py index 2689e5e..125ad79 100644 --- a/astrai/inference/api/protocol.py +++ b/astrai/inference/api/protocol.py @@ -14,24 +14,20 @@ from typing import Any, Dict, List, Optional, Union from fastapi.responses import StreamingResponse from pydantic import BaseModel -from astrai.inference.engine import InferenceEngine +from astrai.inference.engine import GenerationParams, InferenceEngine -class SSEBuilder: - """Fluent builder for SSE (Server-Sent Events) formatted chunks.""" +def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str: + lines: List[str] = [] + if event: + lines.append(f"event: {event}") + lines.append(f"data: {json.dumps(data, ensure_ascii=False)}") + lines.append("") + return "\n".join(lines) - @staticmethod - def event(data: Dict[str, Any], event: Optional[str] = None) -> str: - lines: List[str] = [] - if event: - lines.append(f"event: {event}") - lines.append(f"data: {json.dumps(data, ensure_ascii=False)}") - lines.append("") - return "\n".join(lines) - @staticmethod - def done() -> str: - return "data: [DONE]\n\n" +def _sse_done() -> str: + return "data: [DONE]\n\n" @dataclass @@ -44,6 +40,8 @@ class StreamContext: prompt_tokens: int completion_tokens: int = 0 accumulated: str = "" + stop_matched: Optional[str] = None + last_yield_trimmed: str = "" class StopChecker: @@ -145,13 +143,13 @@ class ProtocolHandler(ABC): prompt_tokens=self._count_prompt_tokens(), ) - agen = self.engine.generate_async( - prompt=self.build_prompt(), + params = GenerationParams( max_tokens=self.request.max_tokens, temperature=self.request.temperature, top_p=self.request.top_p, top_k=self.request.top_k, ) + agen = self.engine.generate_async(prompt=self.build_prompt(), params=params) if self.request.stream: return self._handle_stream(agen, ctx) @@ -180,7 +178,7 @@ class ProtocolHandler(ABC): for event in self.format_stream_end(ctx): yield event - yield SSEBuilder.done() + yield _sse_done() return StreamingResponse( event_stream(), @@ -230,7 +228,7 @@ class OpenAIHandler(ProtocolHandler): def format_stream_start(self, ctx: StreamContext) -> List[str]: return [ - SSEBuilder.event( + _sse_event( { "id": ctx.resp_id, "object": "chat.completion.chunk", @@ -248,7 +246,7 @@ class OpenAIHandler(ProtocolHandler): ] def format_stream_token(self, ctx: StreamContext, token: str) -> str: - return SSEBuilder.event( + return _sse_event( { "id": ctx.resp_id, "object": "chat.completion.chunk", @@ -262,7 +260,7 @@ class OpenAIHandler(ProtocolHandler): def format_stream_end(self, ctx: StreamContext) -> List[str]: return [ - SSEBuilder.event( + _sse_event( { "id": ctx.resp_id, "object": "chat.completion.chunk", @@ -271,7 +269,7 @@ class OpenAIHandler(ProtocolHandler): "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], } ), - SSEBuilder.event( + _sse_event( { "prompt_tokens": ctx.prompt_tokens, "completion_tokens": ctx.completion_tokens, @@ -334,16 +332,16 @@ class AnthropicHandler(ProtocolHandler): if not matched: return None - ctx._stop_matched = matched + ctx.stop_matched = matched trimmed = ctx.accumulated[: ctx.accumulated.rfind(matched)] unyielded = trimmed[len(self._yielded) :] if unyielded: - ctx._last_yield_trimmed = unyielded + ctx.last_yield_trimmed = unyielded return matched def format_stream_start(self, ctx: StreamContext) -> List[str]: return [ - SSEBuilder.event( + _sse_event( { "type": "message_start", "message": { @@ -357,7 +355,7 @@ class AnthropicHandler(ProtocolHandler): }, event="message_start", ), - SSEBuilder.event( + _sse_event( { "type": "content_block_start", "index": 0, @@ -369,7 +367,7 @@ class AnthropicHandler(ProtocolHandler): def format_stream_token(self, ctx: StreamContext, token: str) -> str: self._yielded += token - return SSEBuilder.event( + return _sse_event( { "type": "content_block_delta", "index": 0, @@ -379,12 +377,12 @@ class AnthropicHandler(ProtocolHandler): ) def format_stream_end(self, ctx: StreamContext) -> List[str]: - matched = getattr(ctx, "_stop_matched", None) + matched = ctx.stop_matched events: List[str] = [] - last_yielded = getattr(ctx, "_last_yield_trimmed", "") + last_yielded = ctx.last_yield_trimmed if last_yielded: events.append( - SSEBuilder.event( + _sse_event( { "type": "content_block_delta", "index": 0, @@ -394,13 +392,13 @@ class AnthropicHandler(ProtocolHandler): ) ) events.append( - SSEBuilder.event( + _sse_event( {"type": "content_block_stop", "index": 0}, event="content_block_stop", ) ) events.append( - SSEBuilder.event( + _sse_event( { "type": "message_delta", "delta": { @@ -412,13 +410,13 @@ class AnthropicHandler(ProtocolHandler): event="message_delta", ) ) - events.append(SSEBuilder.event({"type": "message_stop"}, event="message_stop")) + events.append(_sse_event({"type": "message_stop"}, event="message_stop")) return events def format_non_stream_response( self, ctx: StreamContext, content: str ) -> Dict[str, Any]: - matched = getattr(ctx, "_stop_matched", None) + matched = ctx.stop_matched if matched: content = content[: content.rfind(matched)] return { diff --git a/astrai/inference/core/executor.py b/astrai/inference/core/executor.py index 692c4e0..fdabfdd 100644 --- a/astrai/inference/core/executor.py +++ b/astrai/inference/core/executor.py @@ -60,25 +60,6 @@ class Executor: paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len), ) - for i, t in enumerate(tasks): - input_ids[i] = torch.tensor( - t.prompt_ids[start_pos:prompt_len], device=self.device - ) - - task_ids = [t.task_id for t in tasks] - page_tables = self.page_cache.make_table_tensor(task_ids, self.device) - - with torch.inference_mode(): - self.model( - input_ids, - position_ids=torch.arange( - start_pos, prompt_len, dtype=torch.long, device=self.device - ) - .unsqueeze(0) - .expand(batch_sz, -1), - paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len), - ) - def execute_decode(self, tasks: List[Task]) -> List[int]: if not tasks: return [] diff --git a/astrai/inference/engine.py b/astrai/inference/engine.py index 0742ebd..559c510 100644 --- a/astrai/inference/engine.py +++ b/astrai/inference/engine.py @@ -58,15 +58,6 @@ class GenerateResult: return self.results.copy() -def _validate_params(top_k: int, top_p: float, temperature: float) -> None: - if not (isinstance(top_k, int) and top_k >= 0): - raise ValueError("top_k must be a non-negative integer") - if not (0.0 <= top_p <= 1.0): - raise ValueError("top_p must be a float between 0.0 and 1.0") - if not (isinstance(temperature, (int, float)) and temperature >= 0): - raise ValueError("temperature must be a non-negative number") - - @dataclass(frozen=True) class GenerationParams: """Immutable value object for sampling hyperparameters.""" @@ -76,6 +67,14 @@ class GenerationParams: temperature: float = 1.0 max_tokens: int = 1024 + def __post_init__(self): + if not (isinstance(self.top_k, int) and self.top_k >= 0): + raise ValueError("top_k must be a non-negative integer") + if not (0.0 <= self.top_p <= 1.0): + raise ValueError("top_p must be a float between 0.0 and 1.0") + if not (isinstance(self.temperature, (int, float)) and self.temperature >= 0): + raise ValueError("temperature must be a non-negative number") + class GenerationRequest: """Request parameters for text generation.""" @@ -97,7 +96,6 @@ class GenerationRequest: max_tokens=max_len, ) self.stream = stream - _validate_params(top_k, top_p, temperature) @property def top_k(self) -> int: @@ -157,31 +155,32 @@ class InferenceEngine: top_p: float = 1.0, top_k: int = 50, ) -> Union[Generator, str, List[str]]: - _validate_params(top_k, top_p, temperature) + params = GenerationParams( + top_k=top_k, top_p=top_p, temperature=temperature, max_tokens=max_tokens + ) is_batch = isinstance(prompt, list) prompts = prompt if is_batch else [prompt] if stream: - return self._generate_streaming( - prompts, is_batch, max_tokens, temperature, top_p, top_k - ) + return self._generate_streaming(prompts, is_batch, params) else: - return self._generate_non_streaming( - prompts, is_batch, max_tokens, temperature, top_p, top_k - ) + return self._generate_non_streaming(prompts, is_batch, params) def generate_async( self, prompt: str, + params: Optional[GenerationParams] = None, max_tokens: int = 1024, temperature: float = 1.0, top_p: float = 1.0, top_k: int = 50, ) -> AsyncGenerator[str, None]: - sync_gen = self._generate_streaming( - [prompt], False, max_tokens, temperature, top_p, top_k - ) + if params is None: + params = GenerationParams( + top_k=top_k, top_p=top_p, temperature=temperature, max_tokens=max_tokens + ) + sync_gen = self._generate_streaming([prompt], False, params) async def _agen(): loop = asyncio.get_event_loop() @@ -214,12 +213,7 @@ class InferenceEngine: ) def _submit_tasks( - self, - prompts: List[str], - max_tokens: int, - temperature: float, - top_p: float, - top_k: int, + self, prompts: List[str], params: GenerationParams ) -> Tuple[GenerateResult, List[str]]: n = len(prompts) result = GenerateResult(count=n) @@ -228,10 +222,10 @@ class InferenceEngine: cb = self._make_callback(result, i) task_id = self.scheduler.add_task( prompt=p, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - top_k=top_k, + max_tokens=params.max_tokens, + temperature=params.temperature, + top_p=params.top_p, + top_k=params.top_k, stream_callback=cb, ) task_ids.append(task_id) @@ -245,17 +239,9 @@ class InferenceEngine: return cb def _generate_streaming( - self, - prompts: List[str], - is_batch: bool, - max_tokens: int, - temperature: float, - top_p: float, - top_k: int, + self, prompts: List[str], is_batch: bool, params: GenerationParams ) -> Generator: - result, task_ids = self._submit_tasks( - prompts, max_tokens, temperature, top_p, top_k - ) + result, task_ids = self._submit_tasks(prompts, params) n = len(prompts) remaining = n finished = [False] * n @@ -281,17 +267,9 @@ class InferenceEngine: return gen() def _generate_non_streaming( - self, - prompts: List[str], - is_batch: bool, - max_tokens: int, - temperature: float, - top_p: float, - top_k: int, + self, prompts: List[str], is_batch: bool, params: GenerationParams ) -> Union[str, List[str]]: - result, task_ids = self._submit_tasks( - prompts, max_tokens, temperature, top_p, top_k - ) + result, task_ids = self._submit_tasks(prompts, params) result.wait_completion() diff --git a/astrai/model/automodel.py b/astrai/model/automodel.py index 3cd6e8e..d86a523 100644 --- a/astrai/model/automodel.py +++ b/astrai/model/automodel.py @@ -4,13 +4,13 @@ AutoModel base class for model loading and saving. from contextlib import contextmanager from pathlib import Path -from typing import Self, Type, Union +from typing import Self, Union import safetensors.torch as st import torch.nn as nn from astrai.config import ModelConfig -from astrai.factory import Registry +from astrai.factory import BaseFactory @contextmanager @@ -39,46 +39,16 @@ def _disable_random_init(enable: bool = True): setattr(nn.init, name, orig_func) -class AutoModel(nn.Module): +class AutoModel(BaseFactory["AutoModel"], nn.Module): """ Autoregressive language model base class. - Provides model loading/saving and generation capabilities. + Provides model loading/saving, registration, and generation. """ - _registry = Registry() - def __init__(self, config: ModelConfig): super().__init__() self.config = config - @classmethod - def register(cls, model_type: str): - """ - Class method decorator to register model type. - - Usage: - @AutoModel.register('transformer') - class Transformer(AutoModel): - ... - """ - - def decorator(sub_cls: Type["AutoModel"]) -> Type["AutoModel"]: - cls._registry.register(model_type.lower(), sub_cls) - return sub_cls - - return decorator - - @classmethod - def get_model_class(cls, model_type: str) -> Type["AutoModel"]: - """Get model class by model_type string.""" - model_type = model_type.lower() - if not cls._registry.contains(model_type): - available = cls._registry.list_names() - raise ValueError( - f"Unknown model_type: {model_type}. Available: {available}" - ) - return cls._registry.get(model_type) - @classmethod def from_pretrained( cls, @@ -98,7 +68,7 @@ class AutoModel(nn.Module): raise FileNotFoundError(f"Config file not found: {config_path}") model_type = config.model_type or "transformer" - actual_cls = cls.get_model_class(model_type) + actual_cls = AutoModel.get_component_class(model_type) with _disable_random_init(enable=disable_random_init): model = actual_cls(config) diff --git a/astrai/trainer/train_callback.py b/astrai/trainer/train_callback.py index 623cda6..1f18789 100644 --- a/astrai/trainer/train_callback.py +++ b/astrai/trainer/train_callback.py @@ -69,12 +69,6 @@ class CallbackFactory(BaseFactory[TrainCallback]): callback = CallbackFactory.create("my_callback", **kwargs) """ - @classmethod - def _validate_component(cls, callback_cls: type) -> None: - """Validate that the callback class inherits from TrainCallback.""" - if not issubclass(callback_cls, TrainCallback): - raise TypeError(f"{callback_cls.__name__} must inherit from TrainCallback") - @CallbackFactory.register("gradient_clipping") class GradientClippingCallback(TrainCallback):