288 lines
10 KiB
Python
288 lines
10 KiB
Python
"""Unit tests for protocol builders, StopChecker, GenContext, StopInfo."""
|
|
|
|
import json
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
|
|
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
|
from astrai.inference.api.openai import OpenAIResponseBuilder
|
|
from astrai.inference.api.protocol import GenContext, StopChecker, StopInfo
|
|
from astrai.inference.engine import GenerationRequest
|
|
|
|
|
|
def _make_ctx(**kwargs):
    """Return a ``GenContext`` filled with fixed test defaults.

    Keyword arguments override the default of the same name; any extra
    keywords are forwarded unchanged to the ``GenContext`` constructor.
    """
    base = dict(
        resp_id="test-123",
        created=1000,
        model="test-model",
        prompt_tokens=10,
        completion_tokens=5,
    )
    return GenContext(**{**base, **kwargs})
|
|
|
|
|
|
def _sse_payloads(events):
|
|
payloads = []
|
|
for chunk in events:
|
|
for line in chunk.strip().split("\n"):
|
|
if line.startswith("data: "):
|
|
try:
|
|
payloads.append(json.loads(line[6:]))
|
|
except json.JSONDecodeError:
|
|
pass
|
|
return payloads
|
|
|
|
|
|
class TestStopChecker:
    """Behaviour of ``StopChecker.check`` on matching and non-matching text."""

    def test_check_finds_match(self):
        checker = StopChecker(["stop", "end"])
        found = checker.check("hello stop world")
        assert found == "stop"

    def test_check_returns_none_when_no_match(self):
        checker = StopChecker(["stop"])
        found = checker.check("hello world")
        assert found is None

    def test_check_empty_sequences(self):
        # With no stop sequences configured there is nothing to match.
        found = StopChecker([]).check("hello")
        assert found is None
|
|
|
|
|
|
class TestGenContext:
    """Construction defaults and field mutability of ``GenContext``."""

    @staticmethod
    def _new_ctx():
        # completion_tokens is intentionally omitted so its default applies.
        return GenContext(resp_id="a", created=1, model="m", prompt_tokens=10)

    def test_defaults(self):
        assert self._new_ctx().completion_tokens == 0

    def test_fields_mutable(self):
        context = self._new_ctx()
        context.completion_tokens = 42
        assert context.completion_tokens == 42
|
|
|
|
|
|
class TestStopInfo:
    """Default and explicitly supplied field values of ``StopInfo``."""

    def test_defaults(self):
        info = StopInfo()
        assert info.matched is None
        assert info.body == ""
        assert info.yielded == ""

    def test_with_values(self):
        expected = {"matched": "stop", "body": "hello stop", "yielded": "hello "}
        info = StopInfo(**expected)
        for field_name, value in expected.items():
            assert getattr(info, field_name) == value
|
|
|
|
|
|
class TestOpenAIResponseBuilder:
    """Tests for OpenAI-style request preparation and SSE/JSON response output."""

    @staticmethod
    def _request(messages, stop, model):
        # Mock request carrying only the attributes these tests configure.
        req = MagicMock()
        req.messages = messages
        req.stop = stop
        req.model = model
        return req

    @staticmethod
    def _engine(rendered_prompt):
        # Mock engine whose chat template always renders to a fixed string.
        engine = MagicMock()
        engine.tokenizer.apply_chat_template.return_value = rendered_prompt
        return engine

    @pytest.fixture
    def builder(self):
        builder = OpenAIResponseBuilder()
        # prepare() is called once up front; presumably this primes
        # per-request state used by the format_* methods — TODO confirm.
        builder.prepare(
            self._request([MagicMock(role="user", content="Hello")], None, "astrai"),
            self._engine("Hello"),
        )
        return builder

    def test_prepare_returns_prompt_ctx_stops(self, builder):
        req = self._request([MagicMock(role="user", content="Hi")], ["END"], "gpt")
        prompt, ctx, stops = builder.prepare(req, self._engine("Hi"))
        assert prompt == "Hi"
        assert ctx.model == "gpt"
        assert ctx.prompt_tokens == 0
        assert stops == ["END"]

    def test_prepare_no_stop_returns_empty_list(self, builder):
        _, _, stops = builder.prepare(self._request([], None, "x"), self._engine(""))
        assert stops == []

    def test_format_stream_start(self, builder):
        ctx = _make_ctx()
        events = builder.format_stream_start(ctx)
        payloads = _sse_payloads(events)
        assert len(payloads) == 1
        p = payloads[0]
        assert p["object"] == "chat.completion.chunk"
        assert p["choices"][0]["delta"]["role"] == "assistant"
        assert p["choices"][0]["finish_reason"] is None

    def test_format_chunk(self, builder):
        event = builder.format_chunk("hello")
        # A single event string: decode the JSON after the first "data: ".
        payload = json.loads(event.split("data: ", 1)[1])
        assert payload["choices"][0]["delta"]["content"] == "hello"
        assert payload["choices"][0]["finish_reason"] is None

    def test_format_stream_end(self, builder):
        ctx = _make_ctx(completion_tokens=5)
        stop = StopInfo(matched="stop")
        events = builder.format_stream_end(ctx, stop)
        payloads = _sse_payloads(events)
        finish = payloads[0]
        assert finish["choices"][0]["finish_reason"] == "stop"
        usage = payloads[1]
        assert usage["completion_tokens"] == 5
        # _make_ctx defaults prompt_tokens to 10, so the total is 10 + 5.
        assert usage["total_tokens"] == 15

    def test_format_response(self, builder):
        ctx = _make_ctx()
        stop = StopInfo()
        resp = builder.format_response(ctx, "hello", stop)
        assert resp["object"] == "chat.completion"
        assert resp["choices"][0]["message"]["content"] == "hello"
        assert resp["usage"]["prompt_tokens"] == 10
|
|
|
|
|
|
class TestAnthropicResponseBuilder:
    """Tests for Anthropic-style request preparation and SSE/JSON response output."""

    @staticmethod
    def _request(messages, model, stop_sequences=None):
        # system and stop_sequences are set explicitly: an unset attribute on
        # a MagicMock is an auto-created (truthy) mock, not None.
        req = MagicMock()
        req.messages = messages
        req.model = model
        req.system = None
        req.stop_sequences = stop_sequences
        return req

    @staticmethod
    def _engine(rendered_prompt):
        # Mock engine whose chat template always renders to a fixed string.
        engine = MagicMock()
        engine.tokenizer.apply_chat_template.return_value = rendered_prompt
        return engine

    @pytest.fixture
    def builder(self):
        builder = AnthropicResponseBuilder()
        # prepare() is called once up front; presumably this primes
        # per-request state used by the format_* methods — TODO confirm.
        builder.prepare(
            self._request([MagicMock(role="user", content="Hello")], "claude"),
            self._engine("Hello"),
        )
        return builder

    def test_prepare_messages(self, builder):
        req = self._request([MagicMock(role="user", content="Hi")], "claude")
        prompt, _, stops = builder.prepare(req, self._engine("Hi"))
        assert prompt == "Hi"
        assert stops == []

    def test_prepare_with_stop_sequences(self, builder):
        req = self._request([], "x", stop_sequences=["stop", "end"])
        _, _, stops = builder.prepare(req, self._engine(""))
        assert stops == ["stop", "end"]

    def test_format_stream_start(self, builder):
        ctx = _make_ctx(prompt_tokens=3)
        events = builder.format_stream_start(ctx)
        payloads = _sse_payloads(events)
        assert len(payloads) == 2
        assert payloads[0]["type"] == "message_start"
        assert payloads[0]["message"]["usage"]["input_tokens"] == 3
        assert payloads[1]["type"] == "content_block_start"

    def test_format_chunk(self, builder):
        event = builder.format_chunk("tok")
        # A single event string: decode the JSON after the first "data: ".
        payload = json.loads(event.split("data: ", 1)[1])
        assert payload["type"] == "content_block_delta"
        assert payload["delta"]["text"] == "tok"

    def test_format_stream_end_no_stop(self, builder):
        ctx = _make_ctx(completion_tokens=3)
        stop = StopInfo()
        events = builder.format_stream_end(ctx, stop)
        payloads = _sse_payloads(events)
        # content_block_stop, message_delta, message_stop
        types = [p["type"] for p in payloads]
        assert types == ["content_block_stop", "message_delta", "message_stop"]
        assert payloads[1]["delta"]["stop_reason"] == "end_turn"

    def test_format_stream_end_with_stop_trims_and_emits_remaining(self, builder):
        ctx = _make_ctx(completion_tokens=7)
        # Body cut at "END" is "Hello world "; "Hello " was already streamed,
        # so only "world " should be flushed before closing the block.
        stop = StopInfo(
            matched="END",
            body="Hello world END extra",
            yielded="Hello ",
        )
        events = builder.format_stream_end(ctx, stop)
        payloads = _sse_payloads(events)
        # unyielded delta, content_block_stop, message_delta, message_stop
        types = [p["type"] for p in payloads]
        assert types == [
            "content_block_delta",
            "content_block_stop",
            "message_delta",
            "message_stop",
        ]
        assert payloads[0]["delta"]["text"] == "world "
        assert payloads[2]["delta"]["stop_reason"] == "stop_sequence"
        assert payloads[2]["delta"]["stop_sequence"] == "END"

    def test_format_stream_end_stop_trimmed_already_yielded(self, builder):
        ctx = _make_ctx()
        stop = StopInfo(
            matched="END",
            body="Hello END",
            yielded="Hello ",
        )
        events = builder.format_stream_end(ctx, stop)
        payloads = _sse_payloads(events)
        # No unyielded delta (everything already sent)
        types = [p["type"] for p in payloads]
        assert types == ["content_block_stop", "message_delta", "message_stop"]

    def test_format_response_with_stop_trims_content(self, builder):
        ctx = _make_ctx()
        stop = StopInfo(matched="STOP", body="text STOP extra", yielded="text ")
        resp = builder.format_response(ctx, "text STOP extra", stop)
        assert resp["content"][0]["text"] == "text "
        assert resp["stop_reason"] == "stop_sequence"
        assert resp["stop_sequence"] == "STOP"

    def test_format_response_no_stop(self, builder):
        ctx = _make_ctx()
        stop = StopInfo()
        resp = builder.format_response(ctx, "full text", stop)
        assert resp["content"][0]["text"] == "full text"
        assert resp["stop_reason"] == "end_turn"
|
|
|
|
|
|
class TestGenerationRequestValidation:
    """Sampling-parameter validation performed by ``GenerationRequest``."""

    @staticmethod
    def _request(**sampling):
        # A fresh single-turn message list per call, so tests share no state.
        return GenerationRequest(
            messages=[{"role": "user", "content": "hi"}], **sampling
        )

    def test_valid_params(self):
        request = self._request(top_k=50, top_p=0.9, temperature=0.7)
        assert request.top_k == 50

    def test_invalid_top_p_raises(self):
        with pytest.raises(ValueError, match="top_p"):
            self._request(top_p=1.5)

    def test_invalid_top_k_raises(self):
        with pytest.raises(ValueError, match="top_k"):
            self._request(top_k=-1)

    def test_invalid_temperature_raises(self):
        with pytest.raises(ValueError, match="temperature"):
            self._request(temperature=-0.1)

    def test_top_k_zero_valid(self):
        # Zero is the boundary value and must be accepted, unlike -1.
        request = self._request(top_k=0)
        assert request.top_k == 0
|