AstrAI/tests/inference/test_server.py

143 lines
4.3 KiB
Python

"""Unit tests for the inference HTTP server."""
from unittest.mock import MagicMock
import pytest
def test_health_no_model(client, monkeypatch):
    """Health endpoint answers OK, with ``model_loaded`` falsy, when no engine is set.

    The server's ``_state.engine`` attribute is patched to ``None`` for the
    duration of the test so the check runs as if no model had been loaded.
    """
    engine_target = "astrai.inference.server._state.engine"
    monkeypatch.setattr(engine_target, None)

    resp = client.get("/health")
    assert resp.status_code == 200

    payload = resp.json()
    assert payload["status"] == "ok"
    assert not payload["model_loaded"]
def test_health_with_model(client, loaded_model):
    """Health endpoint answers OK and reports ``model_loaded is True``.

    ``loaded_model`` is requested only for its fixture setup; it is not
    referenced in the body. Presumably it installs an engine on the server
    state — confirm against the fixture definition.
    """
    resp = client.get("/health")
    assert resp.status_code == 200

    payload = resp.json()
    assert payload["status"] == "ok"
    # Identity check: the flag must be the boolean True, not merely truthy.
    assert payload["model_loaded"] is True
def test_generate_non_stream(client, loaded_model, monkeypatch):
    """Non-streaming ``POST /generate`` returns a JSON body containing ``response``.

    The prompt and sampling options travel as query parameters, with
    ``stream`` set to ``False``. ``monkeypatch`` is requested but not used.
    """
    sampling = {"temperature": 0.8, "top_p": 0.95, "top_k": 50, "max_len": 100}
    # Key order (query, sampling options, stream) matches the query string layout.
    query_params = {"query": "Hello", **sampling, "stream": False}

    resp = client.post("/generate", params=query_params)
    assert resp.status_code == 200

    payload = resp.json()
    assert "response" in payload
def test_generate_stream(client, loaded_model, monkeypatch):
    """Streaming ``POST /generate`` returns every stubbed chunk in a plain-text body.

    The engine's ``generate_async`` is made to return a two-item async
    generator. Both items must appear somewhere in the UTF-8 decoded body.
    """
    expected_chunks = ("chunk1", "chunk2")

    async def fake_stream():
        for piece in expected_chunks:
            yield piece

    # Stub the engine's async generator, then install the engine on the server state.
    loaded_model.generate_async.return_value = fake_stream()
    monkeypatch.setattr("astrai.inference.server._state.engine", loaded_model)

    resp = client.post(
        "/generate",
        params={
            "query": "Hello",
            "temperature": 0.8,
            "top_p": 0.95,
            "top_k": 50,
            "max_len": 100,
            "stream": True,
        },
        headers={"Accept": "text/plain"},
    )
    assert resp.status_code == 200

    text = resp.content.decode("utf-8")
    for piece in expected_chunks:
        assert piece in text
def test_chat_completions_non_stream(client, loaded_model, monkeypatch):
    """Non-streaming chat completion returns an OpenAI-style ``chat.completion`` object.

    Only the envelope is checked: the object type, exactly one choice, and a
    ``usage`` block containing ``prompt_tokens``. The reply text is not asserted.
    """
    async def single_reply():
        yield "Assistant reply"

    loaded_model.generate_async.return_value = single_reply()
    monkeypatch.setattr("astrai.inference.server._state.engine", loaded_model)

    request_body = {
        "messages": [{"role": "user", "content": "Hello"}],
        "temperature": 0.8,
        "max_tokens": 100,
        "stream": False,
    }
    resp = client.post("/v1/chat/completions", json=request_body)
    assert resp.status_code == 200

    completion = resp.json()
    assert completion["object"] == "chat.completion"
    assert len(completion["choices"]) == 1
    assert "usage" in completion
    assert "prompt_tokens" in completion["usage"]
def test_chat_completions_stream(client, loaded_model, monkeypatch):
    """Streaming chat completion output contains both stubbed chunks and ``[DONE]``.

    The request asks for ``text/event-stream``. Each expected marker must
    appear on at least one non-empty line of the decoded body. Ordering
    between markers is not checked.
    """
    async def cumulative_stream():
        yield "cumulative1"
        yield "cumulative2"

    loaded_model.generate_async.return_value = cumulative_stream()
    monkeypatch.setattr("astrai.inference.server._state.engine", loaded_model)

    resp = client.post(
        "/v1/chat/completions",
        json={
            "messages": [{"role": "user", "content": "Hello"}],
            "temperature": 0.8,
            "max_tokens": 100,
            "stream": True,
        },
        headers={"Accept": "text/event-stream"},
    )
    assert resp.status_code == 200

    raw_body = resp.content.decode("utf-8")
    # Skip empty lines (e.g. blank separators between events) before matching.
    stripped_lines = [raw.strip() for raw in raw_body.split("\n") if raw]
    for marker in ("cumulative1", "cumulative2", "[DONE]"):
        assert any(marker in line for line in stripped_lines)
def test_generate_with_history(client, loaded_model, monkeypatch):
    """Non-streaming ``POST /generate`` accepts a ``history`` query parameter.

    Besides the status code, the body is checked for the same ``response``
    key that ``test_generate_non_stream`` expects from this endpoint with
    ``stream=False``. Without it, a 200 carrying an unexpected payload would
    pass silently. ``monkeypatch`` is requested but not used.
    """
    # NOTE(review): ``history`` is a list of [user, assistant] pairs sent as a
    # query parameter. How the test client encodes the nested lists into the
    # query string is client-dependent — confirm the server parses them as
    # intended rather than merely tolerating them.
    response = client.post(
        "/generate",
        params={
            "query": "Hi",
            "history": [["user1", "assistant1"], ["user2", "assistant2"]],
            "stream": False,
        },
    )
    assert response.status_code == 200
    assert "response" in response.json()
if __name__ == "__main__":
    # pytest.main returns the session exit code. Propagate it so running this
    # file directly reports failures to the shell or CI; previously the
    # script always finished with exit status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))