fix: 修复 remove_task 未释放 KV cache slot 导致第二轮对话死锁
- remove_task() 现在释放 KV cache slot 和 prefix cache 引用
- _refill_active_batch 中 alloc 失败时将剩余 task 推回 waiting_queue
- 主循环增加 try/except 异常兜底,发送 _STOP 给所有 task
- 重构:server.py 全局变量改为 ServerState 类;automodel.py 使用 Registry 替代裸 dict;合并 TrainContextBuilder 的 with_* 方法到 build()
This commit is contained in:
parent
ffff05b2c6
commit
a6f5ff3b37
|
|
@ -2,7 +2,6 @@
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import gc
|
import gc
|
||||||
import logging
|
|
||||||
import threading
|
import threading
|
||||||
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union
|
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union
|
||||||
|
|
||||||
|
|
@ -12,8 +11,6 @@ import torch.nn as nn
|
||||||
from astrai.inference.scheduler import _STOP, InferenceScheduler
|
from astrai.inference.scheduler import _STOP, InferenceScheduler
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class GenerationRequest:
|
class GenerationRequest:
|
||||||
"""Request parameters for text generation.
|
"""Request parameters for text generation.
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
"""Inference scheduler for single-GPU continuous batching."""
|
"""Inference scheduler for single-GPU continuous batching."""
|
||||||
|
|
||||||
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
|
@ -12,6 +13,8 @@ from torch import Tensor
|
||||||
from astrai.model.automodel import AutoModel
|
from astrai.model.automodel import AutoModel
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_STOP = object()
|
_STOP = object()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -506,9 +509,20 @@ class InferenceScheduler:
|
||||||
task_id: The task to remove.
|
task_id: The task to remove.
|
||||||
"""
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
removed_active = [t for t in self.active_tasks if t.task_id == task_id]
|
||||||
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
|
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
|
||||||
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
|
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
|
||||||
|
|
||||||
|
for task in removed_active:
|
||||||
|
if task.prefix_len > 0:
|
||||||
|
prefix = tuple(task.prompt_ids[: task.prefix_len])
|
||||||
|
self.prefix_cache.release(prefix)
|
||||||
|
if task.prefix_len < len(task.prompt_ids):
|
||||||
|
self.prefix_cache.release(tuple(task.prompt_ids))
|
||||||
|
if task.slot >= 0:
|
||||||
|
self._free_slot(task.slot)
|
||||||
|
task.slot = -1
|
||||||
|
|
||||||
def _remove_finished_tasks(self) -> None:
|
def _remove_finished_tasks(self) -> None:
|
||||||
"""Removes all finished tasks from the active batch.
|
"""Removes all finished tasks from the active batch.
|
||||||
|
|
||||||
|
|
@ -553,7 +567,7 @@ class InferenceScheduler:
|
||||||
for _ in range(n):
|
for _ in range(n):
|
||||||
to_add.append(self.waiting_queue.pop(0))
|
to_add.append(self.waiting_queue.pop(0))
|
||||||
|
|
||||||
for task in to_add:
|
for i, task in enumerate(to_add):
|
||||||
slot = -1
|
slot = -1
|
||||||
reused = False
|
reused = False
|
||||||
if task.prefix_len > 0:
|
if task.prefix_len > 0:
|
||||||
|
|
@ -564,6 +578,8 @@ class InferenceScheduler:
|
||||||
if slot < 0:
|
if slot < 0:
|
||||||
slot = self._alloc_slot()
|
slot = self._alloc_slot()
|
||||||
if slot < 0:
|
if slot < 0:
|
||||||
|
with self._lock:
|
||||||
|
self.waiting_queue[:0] = to_add[i:]
|
||||||
break
|
break
|
||||||
task.slot = slot
|
task.slot = slot
|
||||||
task.status = TaskStatus.RUNNING
|
task.status = TaskStatus.RUNNING
|
||||||
|
|
@ -712,6 +728,7 @@ class InferenceScheduler:
|
||||||
Decode processes only the largest position group to ensure all
|
Decode processes only the largest position group to ensure all
|
||||||
batched tasks share the same KV cache write position.
|
batched tasks share the same KV cache write position.
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
while self._running:
|
while self._running:
|
||||||
self._remove_finished_tasks()
|
self._remove_finished_tasks()
|
||||||
self._refill_active_batch()
|
self._refill_active_batch()
|
||||||
|
|
@ -738,6 +755,15 @@ class InferenceScheduler:
|
||||||
if not self.waiting_queue and len(self.active_tasks) <= 1:
|
if not self.waiting_queue and len(self.active_tasks) <= 1:
|
||||||
self._task_event.wait(timeout=0.005)
|
self._task_event.wait(timeout=0.005)
|
||||||
self._task_event.clear()
|
self._task_event.clear()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
|
||||||
|
for task in self.active_tasks:
|
||||||
|
if task.stream_callback:
|
||||||
|
task.stream_callback(_STOP)
|
||||||
|
for task in self.waiting_queue:
|
||||||
|
if task.stream_callback:
|
||||||
|
task.stream_callback(_STOP)
|
||||||
|
raise
|
||||||
|
|
||||||
def start(self) -> None:
|
def start(self) -> None:
|
||||||
"""Starts the background generation loop thread."""
|
"""Starts the background generation loop thread."""
|
||||||
|
|
|
||||||
|
|
@ -23,11 +23,22 @@ from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_engine: Optional[InferenceEngine] = None
|
|
||||||
_model_param: Optional[Any] = None
|
|
||||||
_project_root = Path(__file__).parent.parent.parent
|
_project_root = Path(__file__).parent.parent.parent
|
||||||
|
|
||||||
_server_config: Dict[str, Any] = {
|
|
||||||
|
class ServerState:
|
||||||
|
"""Encapsulates all server runtime state.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
engine: The inference engine instance.
|
||||||
|
model_param: The loaded model.
|
||||||
|
config: Server configuration dict.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.engine: Optional[InferenceEngine] = None
|
||||||
|
self.model_param: Optional[Any] = None
|
||||||
|
self.config: Dict[str, Any] = {
|
||||||
"device": "cuda",
|
"device": "cuda",
|
||||||
"dtype": torch.bfloat16,
|
"dtype": torch.bfloat16,
|
||||||
"param_path": None,
|
"param_path": None,
|
||||||
|
|
@ -35,34 +46,38 @@ _server_config: Dict[str, Any] = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
_state = ServerState()
|
||||||
|
|
||||||
|
|
||||||
def configure_server(
|
def configure_server(
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
dtype: torch.dtype = torch.bfloat16,
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
param_path: Optional[Path] = None,
|
param_path: Optional[Path] = None,
|
||||||
max_batch_size: int = 16,
|
max_batch_size: int = 16,
|
||||||
):
|
):
|
||||||
_server_config["device"] = device
|
_state.config.update(
|
||||||
_server_config["dtype"] = dtype
|
device=device,
|
||||||
_server_config["param_path"] = param_path
|
dtype=dtype,
|
||||||
_server_config["max_batch_size"] = max_batch_size
|
param_path=param_path,
|
||||||
|
max_batch_size=max_batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
global _model_param, _engine
|
|
||||||
try:
|
try:
|
||||||
load_model(
|
load_model(
|
||||||
param_path=_server_config["param_path"],
|
param_path=_state.config["param_path"],
|
||||||
device=_server_config["device"],
|
device=_state.config["device"],
|
||||||
dtype=_server_config["dtype"],
|
dtype=_state.config["dtype"],
|
||||||
max_batch_size=_server_config["max_batch_size"],
|
max_batch_size=_state.config["max_batch_size"],
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to load model: {e}")
|
logger.error(f"Failed to load model: {e}")
|
||||||
raise
|
raise
|
||||||
yield
|
yield
|
||||||
if _engine:
|
if _state.engine:
|
||||||
_engine.shutdown()
|
_state.engine.shutdown()
|
||||||
logger.info("Inference engine shutdown complete")
|
logger.info("Inference engine shutdown complete")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -75,25 +90,30 @@ def load_model(
|
||||||
dtype: torch.dtype = torch.bfloat16,
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
max_batch_size: int = 16,
|
max_batch_size: int = 16,
|
||||||
):
|
):
|
||||||
global _model_param, _engine
|
|
||||||
if param_path is None:
|
if param_path is None:
|
||||||
param_path = _project_root / "params"
|
param_path = _project_root / "params"
|
||||||
if not param_path.exists():
|
if not param_path.exists():
|
||||||
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
|
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(param_path)
|
tokenizer = AutoTokenizer.from_pretrained(param_path)
|
||||||
_model_param = AutoModel.from_pretrained(param_path)
|
_state.model_param = AutoModel.from_pretrained(param_path)
|
||||||
_model_param.to(device=device, dtype=dtype)
|
_state.model_param.to(device=device, dtype=dtype)
|
||||||
logger.info(f"Model loaded on {device} with dtype {dtype}")
|
logger.info(f"Model loaded on {device} with dtype {dtype}")
|
||||||
|
|
||||||
_engine = InferenceEngine(
|
_state.engine = InferenceEngine(
|
||||||
model=_model_param,
|
model=_state.model_param,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
max_batch_size=max_batch_size,
|
max_batch_size=max_batch_size,
|
||||||
)
|
)
|
||||||
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
|
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_engine() -> InferenceEngine:
|
||||||
|
if _state.engine is None:
|
||||||
|
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||||
|
return _state.engine
|
||||||
|
|
||||||
|
|
||||||
class ChatMessage(BaseModel):
|
class ChatMessage(BaseModel):
|
||||||
role: str
|
role: str
|
||||||
content: str
|
content: str
|
||||||
|
|
@ -121,30 +141,27 @@ class CompletionResponse(BaseModel):
|
||||||
async def health():
|
async def health():
|
||||||
return {
|
return {
|
||||||
"status": "ok",
|
"status": "ok",
|
||||||
"model_loaded": _model_param is not None,
|
"model_loaded": _state.model_param is not None,
|
||||||
"engine_ready": _engine is not None,
|
"engine_ready": _state.engine is not None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/stats")
|
@app.get("/stats")
|
||||||
async def get_stats():
|
async def get_stats():
|
||||||
if _engine is None:
|
return _get_engine().get_stats()
|
||||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
|
||||||
return _engine.get_stats()
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/chat/completions", response_model=CompletionResponse)
|
@app.post("/v1/chat/completions", response_model=CompletionResponse)
|
||||||
async def chat_completion(request: ChatCompletionRequest):
|
async def chat_completion(request: ChatCompletionRequest):
|
||||||
if _engine is None:
|
engine = _get_engine()
|
||||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
|
||||||
|
|
||||||
prompt = _engine.tokenizer.apply_chat_template(
|
prompt = engine.tokenizer.apply_chat_template(
|
||||||
[{"role": m.role, "content": m.content} for m in request.messages],
|
[{"role": m.role, "content": m.content} for m in request.messages],
|
||||||
tokenize=False,
|
tokenize=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
if request.stream:
|
if request.stream:
|
||||||
agen = _engine.generate_async(
|
agen = engine.generate_async(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
max_tokens=request.max_tokens,
|
max_tokens=request.max_tokens,
|
||||||
temperature=request.temperature,
|
temperature=request.temperature,
|
||||||
|
|
@ -163,7 +180,7 @@ async def chat_completion(request: ChatCompletionRequest):
|
||||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
result = _engine.generate(
|
result = engine.generate(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
stream=False,
|
stream=False,
|
||||||
max_tokens=request.max_tokens,
|
max_tokens=request.max_tokens,
|
||||||
|
|
@ -198,8 +215,7 @@ async def generate(
|
||||||
max_len: int = 2048,
|
max_len: int = 2048,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
):
|
):
|
||||||
if _engine is None:
|
engine = _get_engine()
|
||||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
|
||||||
|
|
||||||
messages = []
|
messages = []
|
||||||
if history:
|
if history:
|
||||||
|
|
@ -209,10 +225,10 @@ async def generate(
|
||||||
messages.append({"role": "assistant", "content": h[1]})
|
messages.append({"role": "assistant", "content": h[1]})
|
||||||
messages.append({"role": "user", "content": query})
|
messages.append({"role": "user", "content": query})
|
||||||
|
|
||||||
prompt = _engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
agen = _engine.generate_async(
|
agen = engine.generate_async(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
max_tokens=max_len,
|
max_tokens=max_len,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
|
|
@ -226,7 +242,7 @@ async def generate(
|
||||||
|
|
||||||
return StreamingResponse(text_stream(), media_type="text/plain")
|
return StreamingResponse(text_stream(), media_type="text/plain")
|
||||||
else:
|
else:
|
||||||
result = _engine.generate(
|
result = engine.generate(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
stream=False,
|
stream=False,
|
||||||
max_tokens=max_len,
|
max_tokens=max_len,
|
||||||
|
|
|
||||||
|
|
@ -4,12 +4,13 @@ AutoModel base class for model loading and saving.
|
||||||
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Self, Type, Union
|
from typing import Self, Type, Union
|
||||||
|
|
||||||
import safetensors.torch as st
|
import safetensors.torch as st
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from astrai.config import ModelConfig
|
from astrai.config import ModelConfig
|
||||||
|
from astrai.factory import Registry
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
|
|
@ -44,8 +45,7 @@ class AutoModel(nn.Module):
|
||||||
Provides model loading/saving and generation capabilities.
|
Provides model loading/saving and generation capabilities.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Model registry - stored as class attribute
|
_registry = Registry()
|
||||||
_registry: Dict[str, Type["AutoModel"]] = {}
|
|
||||||
|
|
||||||
def __init__(self, config: ModelConfig):
|
def __init__(self, config: ModelConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
@ -63,7 +63,7 @@ class AutoModel(nn.Module):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(sub_cls: Type["AutoModel"]) -> Type["AutoModel"]:
|
def decorator(sub_cls: Type["AutoModel"]) -> Type["AutoModel"]:
|
||||||
cls._registry[model_type.lower()] = sub_cls
|
cls._registry.register(model_type.lower(), sub_cls)
|
||||||
return sub_cls
|
return sub_cls
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
@ -72,12 +72,12 @@ class AutoModel(nn.Module):
|
||||||
def get_model_class(cls, model_type: str) -> Type["AutoModel"]:
|
def get_model_class(cls, model_type: str) -> Type["AutoModel"]:
|
||||||
"""Get model class by model_type string."""
|
"""Get model class by model_type string."""
|
||||||
model_type = model_type.lower()
|
model_type = model_type.lower()
|
||||||
if model_type not in cls._registry:
|
if not cls._registry.contains(model_type):
|
||||||
available = list(cls._registry.keys())
|
available = cls._registry.list_names()
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unknown model_type: {model_type}. Available: {available}"
|
f"Unknown model_type: {model_type}. Available: {available}"
|
||||||
)
|
)
|
||||||
return cls._registry[model_type]
|
return cls._registry.get(model_type)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
|
|
@ -96,14 +96,8 @@ class AutoModel(nn.Module):
|
||||||
else:
|
else:
|
||||||
raise FileNotFoundError(f"Config file not found: {config_path}")
|
raise FileNotFoundError(f"Config file not found: {config_path}")
|
||||||
|
|
||||||
# If called from base class, use model_type to determine actual model class
|
|
||||||
if cls is AutoModel:
|
|
||||||
model_type = config.model_type or "transformer"
|
model_type = config.model_type or "transformer"
|
||||||
actual_cls = cls.get_model_class(model_type)
|
actual_cls = cls.get_model_class(model_type)
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Cannot call from_pretrained() on subclass {cls.__name__}"
|
|
||||||
)
|
|
||||||
|
|
||||||
with _disable_random_init(enable=disable_random_init):
|
with _disable_random_init(enable=disable_random_init):
|
||||||
model = actual_cls(config)
|
model = actual_cls(config)
|
||||||
|
|
|
||||||
|
|
@ -34,66 +34,60 @@ class TrainContext:
|
||||||
class TrainContextBuilder:
|
class TrainContextBuilder:
|
||||||
def __init__(self, config: TrainConfig):
|
def __init__(self, config: TrainConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
self._context = TrainContext(
|
self._checkpoint: Optional[Checkpoint] = None
|
||||||
model=config.model,
|
|
||||||
|
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
||||||
|
self._checkpoint = checkpoint
|
||||||
|
return self
|
||||||
|
|
||||||
|
def build(self) -> TrainContext:
|
||||||
|
context = TrainContext(
|
||||||
|
model=self.config.model,
|
||||||
world_size=get_world_size(),
|
world_size=get_world_size(),
|
||||||
rank=get_rank(),
|
rank=get_rank(),
|
||||||
)
|
)
|
||||||
|
|
||||||
device = get_current_device()
|
device = get_current_device()
|
||||||
self._context.model = self._context.model.to(device=device)
|
context.model = context.model.to(device=device)
|
||||||
|
|
||||||
if self.config.nprocs > 1:
|
if self.config.nprocs > 1 and self.config.parallel_wrapper:
|
||||||
fn = self.config.parallel_wrapper
|
context.model = self.config.parallel_wrapper(context.model)
|
||||||
self._context.model = fn(self._context.model)
|
|
||||||
|
|
||||||
self._context.optimizer = self.config.optimizer_fn(self._context.model)
|
if self._checkpoint is not None:
|
||||||
self._context.scheduler = self.config.scheduler_fn(self._context.optimizer)
|
context.epoch = max(self._checkpoint.epoch, self.config.start_epoch)
|
||||||
|
context.iteration = max(self._checkpoint.iteration, self.config.start_batch)
|
||||||
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
context.model.load_state_dict(self._checkpoint.state_dict)
|
||||||
if checkpoint is None:
|
context.checkpoint = self._checkpoint
|
||||||
checkpoint = Checkpoint(
|
|
||||||
state_dict=self._context.model.state_dict(),
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# resume from the assigned checkpoint or assigned iteration
|
context.checkpoint = Checkpoint(
|
||||||
self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
|
state_dict=context.model.state_dict(),
|
||||||
self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
|
)
|
||||||
self._context.model.load_state_dict(checkpoint.state_dict)
|
|
||||||
|
|
||||||
self._context.checkpoint = checkpoint
|
context.optimizer = self.config.optimizer_fn(context.model)
|
||||||
return self
|
context.scheduler = self.config.scheduler_fn(context.optimizer)
|
||||||
|
|
||||||
def with_dataloader(self) -> Self:
|
cfg = self.config
|
||||||
# fix: change batch level iteration to sample level offset
|
sampler_offset = context.iteration * cfg.batch_size
|
||||||
config = self.config
|
sampler = ResumableDistributedSampler(
|
||||||
sampler_offset = self._context.iteration * config.batch_size
|
data_source=cfg.dataset,
|
||||||
resumeable_sampler = ResumableDistributedSampler(
|
start_epoch=context.epoch,
|
||||||
data_source=config.dataset,
|
|
||||||
start_epoch=self._context.epoch,
|
|
||||||
start_iter=sampler_offset,
|
start_iter=sampler_offset,
|
||||||
seed=config.random_seed,
|
seed=cfg.random_seed,
|
||||||
|
)
|
||||||
|
context.dataloader = DataLoader(
|
||||||
|
cfg.dataset,
|
||||||
|
batch_size=cfg.batch_size,
|
||||||
|
sampler=sampler,
|
||||||
|
num_workers=cfg.num_workers,
|
||||||
|
pin_memory=cfg.pin_memory,
|
||||||
|
prefetch_factor=cfg.prefetch_factor,
|
||||||
)
|
)
|
||||||
|
|
||||||
dataloader = DataLoader(
|
context.strategy = StrategyFactory.create(
|
||||||
config.dataset,
|
model=context.model,
|
||||||
batch_size=config.batch_size,
|
|
||||||
sampler=resumeable_sampler,
|
|
||||||
num_workers=config.num_workers,
|
|
||||||
pin_memory=config.pin_memory,
|
|
||||||
prefetch_factor=config.prefetch_factor,
|
|
||||||
)
|
|
||||||
self._context.dataloader = dataloader
|
|
||||||
return self
|
|
||||||
|
|
||||||
def with_strategy(self) -> Self:
|
|
||||||
self._context.strategy = StrategyFactory.create(
|
|
||||||
model=self._context.model,
|
|
||||||
train_type=self.config.strategy,
|
train_type=self.config.strategy,
|
||||||
device=get_current_device(),
|
device=device,
|
||||||
**self.config.extra_kwargs,
|
**self.config.extra_kwargs,
|
||||||
)
|
)
|
||||||
return self
|
|
||||||
|
|
||||||
def build(self) -> TrainContext:
|
return context
|
||||||
return self._context
|
|
||||||
|
|
|
||||||
|
|
@ -35,11 +35,7 @@ class Trainer:
|
||||||
|
|
||||||
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
||||||
return (
|
return (
|
||||||
TrainContextBuilder(self.train_config)
|
TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build()
|
||||||
.with_checkpoint(checkpoint)
|
|
||||||
.with_dataloader()
|
|
||||||
.with_strategy()
|
|
||||||
.build()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _call_callbacks(self, method_name: str, context: TrainContext):
|
def _call_callbacks(self, method_name: str, context: TrainContext):
|
||||||
|
|
|
||||||
|
|
@ -53,5 +53,5 @@ def mock_engine():
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def loaded_model(mock_model_param, monkeypatch):
|
def loaded_model(mock_model_param, monkeypatch):
|
||||||
"""Simulate that the model is loaded."""
|
"""Simulate that the model is loaded."""
|
||||||
monkeypatch.setattr("astrai.inference.server._model_param", mock_model_param)
|
monkeypatch.setattr("astrai.inference.server._state.model_param", mock_model_param)
|
||||||
return mock_model_param
|
return mock_model_param
|
||||||
|
|
|
||||||
|
|
@ -5,8 +5,8 @@ import pytest
|
||||||
|
|
||||||
def test_health_no_model(client, monkeypatch):
|
def test_health_no_model(client, monkeypatch):
|
||||||
"""GET /health should return 200 even when model not loaded."""
|
"""GET /health should return 200 even when model not loaded."""
|
||||||
monkeypatch.setattr("astrai.inference.server._model_param", None)
|
monkeypatch.setattr("astrai.inference.server._state.model_param", None)
|
||||||
monkeypatch.setattr("astrai.inference.server._engine", None)
|
monkeypatch.setattr("astrai.inference.server._state.engine", None)
|
||||||
response = client.get("/health")
|
response = client.get("/health")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
@ -17,7 +17,7 @@ def test_health_no_model(client, monkeypatch):
|
||||||
|
|
||||||
def test_health_with_model(client, loaded_model, mock_engine, monkeypatch):
|
def test_health_with_model(client, loaded_model, mock_engine, monkeypatch):
|
||||||
"""GET /health should return 200 when model is loaded."""
|
"""GET /health should return 200 when model is loaded."""
|
||||||
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
|
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
||||||
response = client.get("/health")
|
response = client.get("/health")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
@ -28,7 +28,7 @@ def test_health_with_model(client, loaded_model, mock_engine, monkeypatch):
|
||||||
|
|
||||||
def test_generate_non_stream(client, loaded_model, mock_engine, monkeypatch):
|
def test_generate_non_stream(client, loaded_model, mock_engine, monkeypatch):
|
||||||
"""POST /generate with stream=false should return JSON response."""
|
"""POST /generate with stream=false should return JSON response."""
|
||||||
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
|
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/generate",
|
"/generate",
|
||||||
params={
|
params={
|
||||||
|
|
@ -54,7 +54,7 @@ def test_generate_stream(client, loaded_model, mock_engine, monkeypatch):
|
||||||
yield "chunk2"
|
yield "chunk2"
|
||||||
|
|
||||||
mock_engine.generate.return_value = stream_gen()
|
mock_engine.generate.return_value = stream_gen()
|
||||||
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
|
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/generate",
|
"/generate",
|
||||||
params={
|
params={
|
||||||
|
|
@ -78,7 +78,7 @@ def test_generate_stream(client, loaded_model, mock_engine, monkeypatch):
|
||||||
def test_chat_completions_non_stream(client, loaded_model, mock_engine, monkeypatch):
|
def test_chat_completions_non_stream(client, loaded_model, mock_engine, monkeypatch):
|
||||||
"""POST /v1/chat/completions with stream=false returns OpenAI‑style JSON."""
|
"""POST /v1/chat/completions with stream=false returns OpenAI‑style JSON."""
|
||||||
mock_engine.generate.return_value = "Assistant reply"
|
mock_engine.generate.return_value = "Assistant reply"
|
||||||
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
|
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/chat/completions",
|
"/v1/chat/completions",
|
||||||
json={
|
json={
|
||||||
|
|
@ -106,7 +106,7 @@ def test_chat_completions_stream(client, loaded_model, mock_engine, monkeypatch)
|
||||||
yield "[DONE]"
|
yield "[DONE]"
|
||||||
|
|
||||||
mock_engine.generate_async.return_value = async_gen()
|
mock_engine.generate_async.return_value = async_gen()
|
||||||
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
|
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/chat/completions",
|
"/v1/chat/completions",
|
||||||
json={
|
json={
|
||||||
|
|
@ -132,7 +132,7 @@ def test_chat_completions_stream(client, loaded_model, mock_engine, monkeypatch)
|
||||||
|
|
||||||
def test_generate_with_history(client, loaded_model, mock_engine, monkeypatch):
|
def test_generate_with_history(client, loaded_model, mock_engine, monkeypatch):
|
||||||
"""POST /generate with history parameter."""
|
"""POST /generate with history parameter."""
|
||||||
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
|
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/generate",
|
"/generate",
|
||||||
params={
|
params={
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue