143 lines
4.3 KiB
Python
143 lines
4.3 KiB
Python
"""Unit tests for the inference HTTP server."""
|
|
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
|
|
|
|
def test_health_no_model(client, monkeypatch):
    """Check that GET /health still answers 200 when no engine is set.

    The server-side engine reference is replaced with ``None`` for the
    duration of the test; the endpoint must report ``status == "ok"`` and a
    falsy ``model_loaded`` flag.
    """
    # Simulate the "nothing loaded yet" state of the server.
    monkeypatch.setattr("astrai.inference.server._state.engine", None)

    resp = client.get("/health")
    assert resp.status_code == 200

    payload = resp.json()
    assert payload["status"] == "ok"
    assert not payload["model_loaded"]
|
|
|
|
|
|
def test_health_with_model(client, loaded_model):
    """Check that GET /health reports a loaded model.

    Relies on the ``loaded_model`` fixture (defined outside this file;
    presumably it installs an engine on the server state) and expects a 200
    reply with ``status == "ok"`` and ``model_loaded`` exactly ``True``.
    """
    resp = client.get("/health")
    assert resp.status_code == 200

    payload = resp.json()
    assert payload["status"] == "ok"
    assert payload["model_loaded"] is True
|
|
|
|
|
|
def test_generate_non_stream(client, loaded_model, monkeypatch):
    """Check that a non-streaming POST /generate answers with JSON.

    The prompt and sampling options are sent as query parameters with
    ``stream`` set to ``False``; the reply must be HTTP 200 and its JSON body
    must contain a ``"response"`` key.
    """
    query_params = {
        "query": "Hello",
        "temperature": 0.8,
        "top_p": 0.95,
        "top_k": 50,
        "max_len": 100,
        "stream": False,
    }

    resp = client.post("/generate", params=query_params)

    assert resp.status_code == 200
    payload = resp.json()
    assert "response" in payload
|
|
|
|
|
|
def test_generate_stream(client, loaded_model, monkeypatch):
    """Check that a streaming POST /generate relays the engine's chunks.

    The engine's ``generate_async`` is stubbed to return an async generator
    yielding two chunks; both must appear in the decoded response body of a
    200 reply.
    """

    async def chunk_source():
        yield "chunk1"
        yield "chunk2"

    # Stub the engine's async generation and install it on the server state.
    engine = loaded_model
    engine.generate_async.return_value = chunk_source()
    monkeypatch.setattr("astrai.inference.server._state.engine", engine)

    query_params = {
        "query": "Hello",
        "temperature": 0.8,
        "top_p": 0.95,
        "top_k": 50,
        "max_len": 100,
        "stream": True,
    }
    resp = client.post(
        "/generate",
        params=query_params,
        headers={"Accept": "text/plain"},
    )

    assert resp.status_code == 200
    body_text = resp.content.decode("utf-8")
    assert "chunk1" in body_text
    assert "chunk2" in body_text
|
|
|
|
|
|
def test_chat_completions_non_stream(client, loaded_model, monkeypatch):
    """Check the non-streaming /v1/chat/completions reply shape.

    With the engine's ``generate_async`` stubbed to yield a single reply, the
    endpoint must return 200 and a JSON body whose ``object`` is
    ``"chat.completion"``, with exactly one choice and a ``usage`` section
    that includes ``prompt_tokens``.
    """

    async def reply_source():
        yield "Assistant reply"

    # Stub the engine's async generation and install it on the server state.
    engine = loaded_model
    engine.generate_async.return_value = reply_source()
    monkeypatch.setattr("astrai.inference.server._state.engine", engine)

    request_body = {
        "messages": [{"role": "user", "content": "Hello"}],
        "temperature": 0.8,
        "max_tokens": 100,
        "stream": False,
    }
    resp = client.post("/v1/chat/completions", json=request_body)

    assert resp.status_code == 200
    payload = resp.json()
    assert payload["object"] == "chat.completion"
    assert len(payload["choices"]) == 1
    assert "usage" in payload
    assert "prompt_tokens" in payload["usage"]
|
|
|
|
|
|
def test_chat_completions_stream(client, loaded_model, monkeypatch):
    """Check the streaming /v1/chat/completions reply.

    With the engine's ``generate_async`` stubbed to yield two values, the
    200 response body must contain a line mentioning each value and a line
    containing the ``[DONE]`` marker.
    """

    async def cumulative_source():
        yield "cumulative1"
        yield "cumulative2"

    # Stub the engine's async generation and install it on the server state.
    engine = loaded_model
    engine.generate_async.return_value = cumulative_source()
    monkeypatch.setattr("astrai.inference.server._state.engine", engine)

    request_body = {
        "messages": [{"role": "user", "content": "Hello"}],
        "temperature": 0.8,
        "max_tokens": 100,
        "stream": True,
    }
    resp = client.post(
        "/v1/chat/completions",
        json=request_body,
        headers={"Accept": "text/event-stream"},
    )

    assert resp.status_code == 200

    # Drop empty lines, then trim surrounding whitespace from the rest.
    body_text = resp.content.decode("utf-8")
    stream_lines = [raw.strip() for raw in body_text.split("\n") if raw]

    assert any("cumulative1" in entry for entry in stream_lines)
    assert any("cumulative2" in entry for entry in stream_lines)
    assert any("[DONE]" in entry for entry in stream_lines)
|
|
|
|
|
|
def test_generate_with_history(client, loaded_model, monkeypatch):
    """Check that POST /generate accepts a ``history`` parameter.

    Sends a non-streaming request whose query parameters include two prior
    exchanges and only requires that the server answers with HTTP 200.
    """
    prior_turns = [["user1", "assistant1"], ["user2", "assistant2"]]
    query_params = {
        "query": "Hi",
        "history": prior_turns,
        "stream": False,
    }

    resp = client.post("/generate", params=query_params)

    assert resp.status_code == 200
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this module directly as a script. pytest.main() returns
    # the run's exit code; forward it so a failing run does not exit with 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|