fix: 修复 to_dict list 类型丢失与 OpenAI stop 参数失效
- to_dict() 增加 list 类型序列化支持,metrics 等字段不再丢失 - OpenAIHandler 补充 get_stop_sequences/on_token,读取 request.stop 并检测停止序列 - 文档类图补充缺失字段、修正关系分类、ChatCompletionRequest 字段增加 Optional
This commit is contained in:
parent
64be81b7b3
commit
785d65436c
|
|
@ -77,6 +77,9 @@ classDiagram
|
||||||
+int start_batch
|
+int start_batch
|
||||||
+str ckpt_dir
|
+str ckpt_dir
|
||||||
+int ckpt_interval
|
+int ckpt_interval
|
||||||
|
+str log_dir
|
||||||
|
+int log_interval
|
||||||
|
+List[str] metrics
|
||||||
+int random_seed
|
+int random_seed
|
||||||
+int num_workers
|
+int num_workers
|
||||||
+Optional[int] prefetch_factor
|
+Optional[int] prefetch_factor
|
||||||
|
|
@ -472,6 +475,10 @@ classDiagram
|
||||||
class CheckpointCallback {
|
class CheckpointCallback {
|
||||||
+str save_dir
|
+str save_dir
|
||||||
+int interval
|
+int interval
|
||||||
|
+bool weight_only
|
||||||
|
+Callable state_dict_fn
|
||||||
|
+Callable save_extra_fn
|
||||||
|
+Callable load_extra_fn
|
||||||
+_save_checkpoint(context)
|
+_save_checkpoint(context)
|
||||||
+on_train_begin(context)
|
+on_train_begin(context)
|
||||||
+on_batch_end(context)
|
+on_batch_end(context)
|
||||||
|
|
@ -483,6 +490,8 @@ classDiagram
|
||||||
|
|
||||||
class ProgressBarCallback {
|
class ProgressBarCallback {
|
||||||
+int num_epoch
|
+int num_epoch
|
||||||
|
+int log_interval
|
||||||
|
+IO file
|
||||||
+on_epoch_begin(context)
|
+on_epoch_begin(context)
|
||||||
+on_batch_end(context)
|
+on_batch_end(context)
|
||||||
+on_epoch_end(context)
|
+on_epoch_end(context)
|
||||||
|
|
@ -491,6 +500,8 @@ classDiagram
|
||||||
class MetricLoggerCallback {
|
class MetricLoggerCallback {
|
||||||
+str log_dir
|
+str log_dir
|
||||||
+int save_interval
|
+int save_interval
|
||||||
|
+int log_interval
|
||||||
|
+List[str] metrics
|
||||||
+on_batch_end(context)
|
+on_batch_end(context)
|
||||||
+on_train_end(context)
|
+on_train_end(context)
|
||||||
+on_error(context)
|
+on_error(context)
|
||||||
|
|
@ -687,7 +698,7 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class SamplingPipeline {
|
class SamplingPipeline {
|
||||||
+List strategies
|
+List[BaseSamplingStrategy] strategies
|
||||||
+apply(logits, filter_value) Tensor
|
+apply(logits, filter_value) Tensor
|
||||||
+sample(logits, filter_value) Tensor
|
+sample(logits, filter_value) Tensor
|
||||||
}
|
}
|
||||||
|
|
@ -711,16 +722,16 @@ classDiagram
|
||||||
class ChatCompletionRequest {
|
class ChatCompletionRequest {
|
||||||
+str model
|
+str model
|
||||||
+List[ChatMessage] messages
|
+List[ChatMessage] messages
|
||||||
+float temperature
|
+Optional[float] temperature
|
||||||
+float top_p
|
+Optional[float] top_p
|
||||||
+int top_k
|
+Optional[int] top_k
|
||||||
+int max_tokens
|
+Optional[int] max_tokens
|
||||||
+bool stream
|
+Optional[bool] stream
|
||||||
+Optional[Union[str, List[str]]] stop
|
+Optional[Union[str, List[str]]] stop
|
||||||
+Optional[int] n
|
+Optional[int] n
|
||||||
+Optional[float] presence_penalty
|
+Optional[float] presence_penalty
|
||||||
+Optional[float] frequency_penalty
|
+Optional[float] frequency_penalty
|
||||||
+Optional[Dict] logit_bias
|
+Optional[Dict[int, float]] logit_bias
|
||||||
+Optional[str] user
|
+Optional[str] user
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -872,7 +883,6 @@ classDiagram
|
||||||
InferenceScheduler *-- KVCache
|
InferenceScheduler *-- KVCache
|
||||||
InferenceScheduler *-- Executor
|
InferenceScheduler *-- Executor
|
||||||
InferenceScheduler *-- TaskManager
|
InferenceScheduler *-- TaskManager
|
||||||
SamplingPipeline *-- BaseSamplingStrategy
|
|
||||||
AutoRegressiveLM *-- DecoderBlock
|
AutoRegressiveLM *-- DecoderBlock
|
||||||
AutoRegressiveLM *-- RotaryEmbedding
|
AutoRegressiveLM *-- RotaryEmbedding
|
||||||
AutoRegressiveLM *-- Embedding
|
AutoRegressiveLM *-- Embedding
|
||||||
|
|
@ -880,9 +890,10 @@ classDiagram
|
||||||
EmbeddingEncoder *-- RotaryEmbedding
|
EmbeddingEncoder *-- RotaryEmbedding
|
||||||
EmbeddingEncoder *-- Embedding
|
EmbeddingEncoder *-- Embedding
|
||||||
DecoderBlock *-- RMSNorm
|
DecoderBlock *-- RMSNorm
|
||||||
BaseDataset o-- BaseStorage
|
|
||||||
ChatCompletionRequest *-- ChatMessage
|
ChatCompletionRequest *-- ChatMessage
|
||||||
MessagesRequest *-- AnthropicMessage
|
MessagesRequest *-- AnthropicMessage
|
||||||
|
AutoTokenizer *-- ChatTemplate
|
||||||
|
BaseFactory *-- Registry
|
||||||
|
|
||||||
%% --- Aggregation (weak ownership) ---
|
%% --- Aggregation (weak ownership) ---
|
||||||
AutoModel o-- BaseModelConfig
|
AutoModel o-- BaseModelConfig
|
||||||
|
|
@ -890,9 +901,9 @@ classDiagram
|
||||||
TrainContext o-- BaseStrategy
|
TrainContext o-- BaseStrategy
|
||||||
TrainContext o-- BaseScheduler
|
TrainContext o-- BaseScheduler
|
||||||
TrainContext o-- Checkpoint
|
TrainContext o-- Checkpoint
|
||||||
AutoTokenizer o-- ChatTemplate
|
|
||||||
KvcacheView o-- Storage
|
KvcacheView o-- Storage
|
||||||
BaseFactory o-- Registry
|
SamplingPipeline o-- BaseSamplingStrategy
|
||||||
|
BaseDataset o-- BaseStorage
|
||||||
|
|
||||||
%% --- Dependency (uses temporarily) ---
|
%% --- Dependency (uses temporarily) ---
|
||||||
TrainConfig ..> BaseStrategy : selects
|
TrainConfig ..> BaseStrategy : selects
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
__version__ = "1.3.5"
|
__version__ = "1.3.6"
|
||||||
__author__ = "ViperEkura"
|
__author__ = "ViperEkura"
|
||||||
|
|
||||||
from astrai.config import (
|
from astrai.config import (
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,6 @@ __all__ = [
|
||||||
"BaseModelConfig",
|
"BaseModelConfig",
|
||||||
"AutoRegressiveLMConfig",
|
"AutoRegressiveLMConfig",
|
||||||
"EncoderConfig",
|
"EncoderConfig",
|
||||||
"ModelConfig",
|
|
||||||
"ConfigFactory",
|
"ConfigFactory",
|
||||||
"TrainConfig",
|
"TrainConfig",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ class BaseConfig:
|
||||||
d[fld.name] = v
|
d[fld.name] = v
|
||||||
elif v is None:
|
elif v is None:
|
||||||
d[fld.name] = None
|
d[fld.name] = None
|
||||||
elif isinstance(v, dict):
|
elif isinstance(v, (dict, list)):
|
||||||
try:
|
try:
|
||||||
json.dumps(v)
|
json.dumps(v)
|
||||||
d[fld.name] = v
|
d[fld.name] = v
|
||||||
|
|
|
||||||
|
|
@ -226,6 +226,17 @@ class OpenAIHandler(ProtocolHandler):
|
||||||
def create_response_id(self) -> str:
|
def create_response_id(self) -> str:
|
||||||
return f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
return f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
||||||
|
|
||||||
|
def get_stop_sequences(self) -> List[str]:
|
||||||
|
stop = self.request.stop
|
||||||
|
if stop is None:
|
||||||
|
return []
|
||||||
|
return [stop] if isinstance(stop, str) else stop
|
||||||
|
|
||||||
|
def on_token(
|
||||||
|
self, ctx: StreamContext, token: str, stop_checker: StopChecker
|
||||||
|
) -> Optional[str]:
|
||||||
|
return stop_checker.check(ctx.accumulated)
|
||||||
|
|
||||||
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
||||||
return [
|
return [
|
||||||
_sse_event(
|
_sse_event(
|
||||||
|
|
|
||||||
|
|
@ -157,5 +157,60 @@ def test_messages_with_system(client, loaded_model):
|
||||||
assert data["type"] == "message"
|
assert data["type"] == "message"
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_completions_stop_sequence(client, loaded_model):
|
||||||
|
"""POST /v1/chat/completions with stop parameter truncates at stop sequence."""
|
||||||
|
|
||||||
|
async def async_gen():
|
||||||
|
yield "Hello"
|
||||||
|
yield "X"
|
||||||
|
yield "world"
|
||||||
|
|
||||||
|
app.state.engine = loaded_model
|
||||||
|
loaded_model.generate_async.return_value = async_gen()
|
||||||
|
response = client.post(
|
||||||
|
"/v1/chat/completions",
|
||||||
|
json={
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
|
"max_tokens": 100,
|
||||||
|
"stream": False,
|
||||||
|
"stop": ["X"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
content = data["choices"][0]["message"]["content"]
|
||||||
|
assert "X" in content
|
||||||
|
assert "world" not in content
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_completions_stop_sequence_stream(client, loaded_model):
|
||||||
|
"""POST /v1/chat/completions with stop parameter truncates SSE stream."""
|
||||||
|
|
||||||
|
async def async_gen():
|
||||||
|
yield "Hello"
|
||||||
|
yield "X"
|
||||||
|
yield "world"
|
||||||
|
|
||||||
|
app.state.engine = loaded_model
|
||||||
|
loaded_model.generate_async.return_value = async_gen()
|
||||||
|
response = client.post(
|
||||||
|
"/v1/chat/completions",
|
||||||
|
json={
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
|
"max_tokens": 100,
|
||||||
|
"stream": True,
|
||||||
|
"stop": ["X"],
|
||||||
|
},
|
||||||
|
headers={"Accept": "text/event-stream"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
content = response.content.decode("utf-8")
|
||||||
|
assert "Hello" in content
|
||||||
|
assert "world" not in content
|
||||||
|
assert any(
|
||||||
|
"finish_reason" in line for line in content.split("\n") if "stop" in line
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pytest.main([__file__, "-v"])
|
pytest.main([__file__, "-v"])
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue