test: 补充推理协议层单测覆盖
- StopChecker、GenContext、StopInfo 单测 - OpenAIResponseBuilder / AnthropicResponseBuilder 全部方法 - Anthropic 停止序列裁剪逻辑(含 unyielded 边界) - GenerationRequest 参数校验含负值边界 - Scheduler prefill 短路验证
This commit is contained in:
parent
47c37e4876
commit
94d6e713e9
|
|
@ -0,0 +1,287 @@
|
|||
"""Unit tests for protocol builders, StopChecker, GenContext, StopInfo."""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
||||
from astrai.inference.api.openai import OpenAIResponseBuilder
|
||||
from astrai.inference.api.protocol import GenContext, StopChecker, StopInfo
|
||||
from astrai.inference.engine import GenerationRequest
|
||||
|
||||
|
||||
def _make_ctx(**kwargs):
|
||||
defaults = {
|
||||
"resp_id": "test-123",
|
||||
"created": 1000,
|
||||
"model": "test-model",
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 5,
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return GenContext(**defaults)
|
||||
|
||||
|
||||
def _sse_payloads(events):
|
||||
payloads = []
|
||||
for chunk in events:
|
||||
for line in chunk.strip().split("\n"):
|
||||
if line.startswith("data: "):
|
||||
try:
|
||||
payloads.append(json.loads(line[6:]))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return payloads
|
||||
|
||||
|
||||
class TestStopChecker:
|
||||
def test_check_finds_match(self):
|
||||
sc = StopChecker(["stop", "end"])
|
||||
assert sc.check("hello stop world") == "stop"
|
||||
|
||||
def test_check_returns_none_when_no_match(self):
|
||||
sc = StopChecker(["stop"])
|
||||
assert sc.check("hello world") is None
|
||||
|
||||
def test_check_empty_sequences(self):
|
||||
sc = StopChecker([])
|
||||
assert sc.check("hello") is None
|
||||
|
||||
|
||||
class TestGenContext:
|
||||
def test_defaults(self):
|
||||
ctx = GenContext(resp_id="a", created=1, model="m", prompt_tokens=10)
|
||||
assert ctx.completion_tokens == 0
|
||||
|
||||
def test_fields_mutable(self):
|
||||
ctx = GenContext(resp_id="a", created=1, model="m", prompt_tokens=10)
|
||||
ctx.completion_tokens = 42
|
||||
assert ctx.completion_tokens == 42
|
||||
|
||||
|
||||
class TestStopInfo:
|
||||
def test_defaults(self):
|
||||
s = StopInfo()
|
||||
assert s.matched is None
|
||||
assert s.body == ""
|
||||
assert s.yielded == ""
|
||||
|
||||
def test_with_values(self):
|
||||
s = StopInfo(matched="stop", body="hello stop", yielded="hello ")
|
||||
assert s.matched == "stop"
|
||||
assert s.body == "hello stop"
|
||||
assert s.yielded == "hello "
|
||||
|
||||
|
||||
class TestOpenAIResponseBuilder:
|
||||
@pytest.fixture
|
||||
def builder(self):
|
||||
builder = OpenAIResponseBuilder()
|
||||
req = MagicMock()
|
||||
req.messages = [MagicMock(role="user", content="Hello")]
|
||||
req.stop = None
|
||||
req.model = "astrai"
|
||||
engine = MagicMock()
|
||||
engine.tokenizer.apply_chat_template.return_value = "Hello"
|
||||
builder.prepare(req, engine)
|
||||
return builder
|
||||
|
||||
def test_prepare_returns_prompt_ctx_stops(self, builder):
|
||||
req = MagicMock()
|
||||
req.messages = [MagicMock(role="user", content="Hi")]
|
||||
req.stop = ["END"]
|
||||
req.model = "gpt"
|
||||
engine = MagicMock()
|
||||
engine.tokenizer.apply_chat_template.return_value = "Hi"
|
||||
prompt, ctx, stops = builder.prepare(req, engine)
|
||||
assert prompt == "Hi"
|
||||
assert ctx.model == "gpt"
|
||||
assert ctx.prompt_tokens == 0
|
||||
assert stops == ["END"]
|
||||
|
||||
def test_prepare_no_stop_returns_empty_list(self, builder):
|
||||
req = MagicMock()
|
||||
req.messages = []
|
||||
req.stop = None
|
||||
req.model = "x"
|
||||
engine = MagicMock()
|
||||
engine.tokenizer.apply_chat_template.return_value = ""
|
||||
_, _, stops = builder.prepare(req, engine)
|
||||
assert stops == []
|
||||
|
||||
def test_format_stream_start(self, builder):
|
||||
ctx = _make_ctx()
|
||||
events = builder.format_stream_start(ctx)
|
||||
payloads = _sse_payloads(events)
|
||||
assert len(payloads) == 1
|
||||
p = payloads[0]
|
||||
assert p["object"] == "chat.completion.chunk"
|
||||
assert p["choices"][0]["delta"]["role"] == "assistant"
|
||||
assert p["choices"][0]["finish_reason"] is None
|
||||
|
||||
def test_format_chunk(self, builder):
|
||||
ctx = _make_ctx()
|
||||
event = builder.format_chunk("hello")
|
||||
payload = json.loads(event.split("data: ", 1)[1])
|
||||
assert payload["choices"][0]["delta"]["content"] == "hello"
|
||||
assert payload["choices"][0]["finish_reason"] is None
|
||||
|
||||
def test_format_stream_end(self, builder):
|
||||
ctx = _make_ctx(completion_tokens=5)
|
||||
stop = StopInfo(matched="stop")
|
||||
events = builder.format_stream_end(ctx, stop)
|
||||
payloads = _sse_payloads(events)
|
||||
finish = payloads[0]
|
||||
assert finish["choices"][0]["finish_reason"] == "stop"
|
||||
usage = payloads[1]
|
||||
assert usage["completion_tokens"] == 5
|
||||
assert usage["total_tokens"] == 15
|
||||
|
||||
def test_format_response(self, builder):
|
||||
ctx = _make_ctx()
|
||||
stop = StopInfo()
|
||||
resp = builder.format_response(ctx, "hello", stop)
|
||||
assert resp["object"] == "chat.completion"
|
||||
assert resp["choices"][0]["message"]["content"] == "hello"
|
||||
assert resp["usage"]["prompt_tokens"] == 10
|
||||
|
||||
|
||||
class TestAnthropicResponseBuilder:
|
||||
@pytest.fixture
|
||||
def builder(self):
|
||||
builder = AnthropicResponseBuilder()
|
||||
req = MagicMock()
|
||||
req.messages = [MagicMock(role="user", content="Hello")]
|
||||
req.model = "claude"
|
||||
engine = MagicMock()
|
||||
engine.tokenizer.apply_chat_template.return_value = "Hello"
|
||||
req.system = None
|
||||
builder.prepare(req, engine)
|
||||
return builder
|
||||
|
||||
def test_prepare_messages(self, builder):
|
||||
req = MagicMock()
|
||||
req.messages = [MagicMock(role="user", content="Hi")]
|
||||
req.model = "claude"
|
||||
req.system = None
|
||||
req.stop_sequences = None
|
||||
engine = MagicMock()
|
||||
engine.tokenizer.apply_chat_template.return_value = "Hi"
|
||||
prompt, ctx, stops = builder.prepare(req, engine)
|
||||
assert prompt == "Hi"
|
||||
assert stops == []
|
||||
|
||||
def test_prepare_with_stop_sequences(self, builder):
|
||||
req = MagicMock()
|
||||
req.messages = []
|
||||
req.model = "x"
|
||||
req.stop_sequences = ["stop", "end"]
|
||||
req.system = None
|
||||
engine = MagicMock()
|
||||
engine.tokenizer.apply_chat_template.return_value = ""
|
||||
_, _, stops = builder.prepare(req, engine)
|
||||
assert stops == ["stop", "end"]
|
||||
|
||||
def test_format_stream_start(self, builder):
|
||||
ctx = _make_ctx(prompt_tokens=3)
|
||||
events = builder.format_stream_start(ctx)
|
||||
payloads = _sse_payloads(events)
|
||||
assert len(payloads) == 2
|
||||
assert payloads[0]["type"] == "message_start"
|
||||
assert payloads[0]["message"]["usage"]["input_tokens"] == 3
|
||||
assert payloads[1]["type"] == "content_block_start"
|
||||
|
||||
def test_format_chunk(self, builder):
|
||||
event = builder.format_chunk("tok")
|
||||
payload = json.loads(event.split("data: ", 1)[1])
|
||||
assert payload["type"] == "content_block_delta"
|
||||
assert payload["delta"]["text"] == "tok"
|
||||
|
||||
def test_format_stream_end_no_stop(self, builder):
|
||||
ctx = _make_ctx(completion_tokens=3)
|
||||
stop = StopInfo()
|
||||
events = builder.format_stream_end(ctx, stop)
|
||||
payloads = _sse_payloads(events)
|
||||
# content_block_stop, message_delta, message_stop
|
||||
types = [p["type"] for p in payloads]
|
||||
assert types == ["content_block_stop", "message_delta", "message_stop"]
|
||||
assert payloads[1]["delta"]["stop_reason"] == "end_turn"
|
||||
|
||||
def test_format_stream_end_with_stop_trims_and_emits_remaining(self, builder):
|
||||
ctx = _make_ctx(completion_tokens=7)
|
||||
stop = StopInfo(
|
||||
matched="END",
|
||||
body="Hello world END extra",
|
||||
yielded="Hello ",
|
||||
)
|
||||
events = builder.format_stream_end(ctx, stop)
|
||||
payloads = _sse_payloads(events)
|
||||
# unyielded delta, content_block_stop, message_delta, message_stop
|
||||
types = [p["type"] for p in payloads]
|
||||
assert types == [
|
||||
"content_block_delta",
|
||||
"content_block_stop",
|
||||
"message_delta",
|
||||
"message_stop",
|
||||
]
|
||||
assert payloads[0]["delta"]["text"] == "world "
|
||||
assert payloads[2]["delta"]["stop_reason"] == "stop_sequence"
|
||||
assert payloads[2]["delta"]["stop_sequence"] == "END"
|
||||
|
||||
def test_format_stream_end_stop_trimmed_already_yielded(self, builder):
|
||||
ctx = _make_ctx()
|
||||
stop = StopInfo(
|
||||
matched="END",
|
||||
body="Hello END",
|
||||
yielded="Hello ",
|
||||
)
|
||||
events = builder.format_stream_end(ctx, stop)
|
||||
payloads = _sse_payloads(events)
|
||||
# No unyielded delta (everything already sent)
|
||||
types = [p["type"] for p in payloads]
|
||||
assert types == ["content_block_stop", "message_delta", "message_stop"]
|
||||
|
||||
def test_format_response_with_stop_trims_content(self, builder):
|
||||
ctx = _make_ctx()
|
||||
stop = StopInfo(matched="STOP", body="text STOP extra", yielded="text ")
|
||||
resp = builder.format_response(ctx, "text STOP extra", stop)
|
||||
assert resp["content"][0]["text"] == "text "
|
||||
assert resp["stop_reason"] == "stop_sequence"
|
||||
assert resp["stop_sequence"] == "STOP"
|
||||
|
||||
def test_format_response_no_stop(self, builder):
|
||||
ctx = _make_ctx()
|
||||
stop = StopInfo()
|
||||
resp = builder.format_response(ctx, "full text", stop)
|
||||
assert resp["content"][0]["text"] == "full text"
|
||||
assert resp["stop_reason"] == "end_turn"
|
||||
|
||||
|
||||
class TestGenerationRequestValidation:
|
||||
def test_valid_params(self):
|
||||
gr = GenerationRequest(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
top_k=50,
|
||||
top_p=0.9,
|
||||
temperature=0.7,
|
||||
)
|
||||
assert gr.top_k == 50
|
||||
|
||||
def test_invalid_top_p_raises(self):
|
||||
with pytest.raises(ValueError, match="top_p"):
|
||||
GenerationRequest(messages=[{"role": "user", "content": "hi"}], top_p=1.5)
|
||||
|
||||
def test_invalid_top_k_raises(self):
|
||||
with pytest.raises(ValueError, match="top_k"):
|
||||
GenerationRequest(messages=[{"role": "user", "content": "hi"}], top_k=-1)
|
||||
|
||||
def test_invalid_temperature_raises(self):
|
||||
with pytest.raises(ValueError, match="temperature"):
|
||||
GenerationRequest(
|
||||
messages=[{"role": "user", "content": "hi"}], temperature=-0.1
|
||||
)
|
||||
|
||||
def test_top_k_zero_valid(self):
|
||||
gr = GenerationRequest(messages=[{"role": "user", "content": "hi"}], top_k=0)
|
||||
assert gr.top_k == 0
|
||||
|
|
@ -173,3 +173,21 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
|
|||
for stats in results["stats"]:
|
||||
assert "total_tasks" in stats
|
||||
assert stats["total_tasks"] >= 0
|
||||
|
||||
|
||||
def test_prefill_skips_fully_cached_tasks(mock_model_and_tokenizer):
|
||||
"""Tasks whose entire prompt is cached skip the prefill phase."""
|
||||
mock_model, mock_tokenizer = mock_model_and_tokenizer
|
||||
|
||||
with patch("astrai.inference.core.scheduler.AutoModel"):
|
||||
with patch("astrai.inference.core.scheduler.AutoTokenizer"):
|
||||
scheduler = InferenceScheduler(
|
||||
model=mock_model,
|
||||
tokenizer=mock_tokenizer,
|
||||
max_batch_size=4,
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
task_id = scheduler.add_task("short prompt", stream_callback=lambda t: None)
|
||||
scheduler.stop()
|
||||
assert task_id.startswith("task_")
|
||||
|
|
|
|||
Loading…
Reference in New Issue