fix: 修复 to_dict list 类型丢失与 OpenAI stop 参数失效

- to_dict() 增加 list 类型序列化支持,metrics 等字段不再丢失
- OpenAIHandler 补充 get_stop_sequences/on_token,读取 request.stop 并检测停止序列
- 文档类图补充缺失字段、修正关系分类、ChatCompletionRequest 字段增加 Optional
This commit is contained in:
ViperEkura 2026-05-19 21:00:40 +08:00
parent 64be81b7b3
commit 785d65436c
6 changed files with 90 additions and 14 deletions

View File

@ -77,6 +77,9 @@ classDiagram
+int start_batch
+str ckpt_dir
+int ckpt_interval
+str log_dir
+int log_interval
+List[str] metrics
+int random_seed
+int num_workers
+Optional[int] prefetch_factor
@ -472,6 +475,10 @@ classDiagram
class CheckpointCallback {
+str save_dir
+int interval
+bool weight_only
+Callable state_dict_fn
+Callable save_extra_fn
+Callable load_extra_fn
+_save_checkpoint(context)
+on_train_begin(context)
+on_batch_end(context)
@ -483,6 +490,8 @@ classDiagram
class ProgressBarCallback {
+int num_epoch
+int log_interval
+IO file
+on_epoch_begin(context)
+on_batch_end(context)
+on_epoch_end(context)
@ -491,6 +500,8 @@ classDiagram
class MetricLoggerCallback {
+str log_dir
+int save_interval
+int log_interval
+List[str] metrics
+on_batch_end(context)
+on_train_end(context)
+on_error(context)
@ -687,7 +698,7 @@ classDiagram
}
class SamplingPipeline {
+List strategies
+List[BaseSamplingStrategy] strategies
+apply(logits, filter_value) Tensor
+sample(logits, filter_value) Tensor
}
@ -711,16 +722,16 @@ classDiagram
class ChatCompletionRequest {
+str model
+List[ChatMessage] messages
+float temperature
+float top_p
+int top_k
+int max_tokens
+bool stream
+Optional[float] temperature
+Optional[float] top_p
+Optional[int] top_k
+Optional[int] max_tokens
+Optional[bool] stream
+Optional[Union[str, List[str]]] stop
+Optional[int] n
+Optional[float] presence_penalty
+Optional[float] frequency_penalty
+Optional[Dict] logit_bias
+Optional[Dict[int, float]] logit_bias
+Optional[str] user
}
@ -872,7 +883,6 @@ classDiagram
InferenceScheduler *-- KVCache
InferenceScheduler *-- Executor
InferenceScheduler *-- TaskManager
SamplingPipeline *-- BaseSamplingStrategy
AutoRegressiveLM *-- DecoderBlock
AutoRegressiveLM *-- RotaryEmbedding
AutoRegressiveLM *-- Embedding
@ -880,9 +890,10 @@ classDiagram
EmbeddingEncoder *-- RotaryEmbedding
EmbeddingEncoder *-- Embedding
DecoderBlock *-- RMSNorm
BaseDataset o-- BaseStorage
ChatCompletionRequest *-- ChatMessage
MessagesRequest *-- AnthropicMessage
AutoTokenizer *-- ChatTemplate
BaseFactory *-- Registry
%% --- Aggregation (weak ownership) ---
AutoModel o-- BaseModelConfig
@ -890,9 +901,9 @@ classDiagram
TrainContext o-- BaseStrategy
TrainContext o-- BaseScheduler
TrainContext o-- Checkpoint
AutoTokenizer o-- ChatTemplate
KvcacheView o-- Storage
BaseFactory o-- Registry
SamplingPipeline o-- BaseSamplingStrategy
BaseDataset o-- BaseStorage
%% --- Dependency (uses temporarily) ---
TrainConfig ..> BaseStrategy : selects

View File

@ -1,4 +1,4 @@
__version__ = "1.3.5"
__version__ = "1.3.6"
__author__ = "ViperEkura"
from astrai.config import (

View File

@ -11,7 +11,6 @@ __all__ = [
"BaseModelConfig",
"AutoRegressiveLMConfig",
"EncoderConfig",
"ModelConfig",
"ConfigFactory",
"TrainConfig",
]

View File

@ -13,7 +13,7 @@ class BaseConfig:
d[fld.name] = v
elif v is None:
d[fld.name] = None
elif isinstance(v, dict):
elif isinstance(v, (dict, list)):
try:
json.dumps(v)
d[fld.name] = v

View File

@ -226,6 +226,17 @@ class OpenAIHandler(ProtocolHandler):
def create_response_id(self) -> str:
    """Build a fresh identifier for a chat-completion response.

    Returns:
        The literal prefix ``chatcmpl-`` followed by the first 12
        hexadecimal characters of a newly generated UUID4.
    """
    suffix = uuid.uuid4().hex[:12]
    return "chatcmpl-" + suffix
def get_stop_sequences(self) -> List[str]:
    """Return the stop sequences requested by the client as a list.

    ``self.request.stop`` may be ``None``, a single string, or a list of
    strings; all three shapes are normalized to a list here.

    Returns:
        A new list of non-empty stop strings. The list is empty when the
        request carries no ``stop`` value (or only empty strings).
    """
    stop = self.request.stop
    if stop is None:
        return []
    candidates = [stop] if isinstance(stop, str) else stop
    # An empty string is a substring of every text, so it cannot act as a
    # meaningful stop sequence; drop it rather than hand it to the checker.
    # Building a new list also keeps callers from aliasing (and mutating)
    # the list stored on the request object.
    return [seq for seq in candidates if seq]
def on_token(
    self, ctx: StreamContext, token: str, stop_checker: StopChecker
) -> Optional[str]:
    """Run the stop checker against the text accumulated so far.

    Args:
        ctx: Stream context; only ``ctx.accumulated`` is read.
        token: The newly produced token. It is not used directly here;
            the check runs on ``ctx.accumulated`` instead.
        stop_checker: Object whose ``check`` method is given the
            accumulated text.

    Returns:
        Whatever ``stop_checker.check`` returns, unchanged.
        NOTE(review): presumably the matched/truncated text or ``None``
        when no stop sequence is hit; confirm against ``StopChecker``.
    """
    text_so_far = ctx.accumulated
    outcome = stop_checker.check(text_so_far)
    return outcome
def format_stream_start(self, ctx: StreamContext) -> List[str]:
return [
_sse_event(

View File

@ -157,5 +157,60 @@ def test_messages_with_system(client, loaded_model):
assert data["type"] == "message"
def test_chat_completions_stop_sequence(client, loaded_model):
    """Non-streaming /v1/chat/completions request carrying a ``stop`` list.

    The mocked engine yields "Hello", "X", "world"; with ``stop=["X"]`` the
    returned message must contain "X" and must not contain "world".
    """

    async def fake_token_stream():
        for piece in ("Hello", "X", "world"):
            yield piece

    app.state.engine = loaded_model
    loaded_model.generate_async.return_value = fake_token_stream()

    payload = {
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 100,
        "stream": False,
        "stop": ["X"],
    }
    response = client.post("/v1/chat/completions", json=payload)

    assert response.status_code == 200
    message_text = response.json()["choices"][0]["message"]["content"]
    assert "X" in message_text
    assert "world" not in message_text
def test_chat_completions_stop_sequence_stream(client, loaded_model):
    """Streaming /v1/chat/completions request carrying a ``stop`` list.

    The mocked engine yields "Hello", "X", "world"; with ``stop=["X"]`` the
    SSE body must contain "Hello", must not contain "world", and at least
    one line mentioning "stop" must also mention "finish_reason".
    """

    async def fake_token_stream():
        for piece in ("Hello", "X", "world"):
            yield piece

    app.state.engine = loaded_model
    loaded_model.generate_async.return_value = fake_token_stream()

    payload = {
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 100,
        "stream": True,
        "stop": ["X"],
    }
    response = client.post(
        "/v1/chat/completions",
        json=payload,
        headers={"Accept": "text/event-stream"},
    )

    assert response.status_code == 200
    body = response.content.decode("utf-8")
    assert "Hello" in body
    assert "world" not in body
    # Only lines that mention "stop" are candidates for the finish marker.
    stop_lines = [line for line in body.split("\n") if "stop" in line]
    assert any("finish_reason" in line for line in stop_lines)
# When this module is executed directly, run its tests through pytest in
# verbose mode.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])