diff --git a/assets/docs/architecture.md b/assets/docs/architecture.md index c3ac94b..1997318 100644 --- a/assets/docs/architecture.md +++ b/assets/docs/architecture.md @@ -77,6 +77,9 @@ classDiagram +int start_batch +str ckpt_dir +int ckpt_interval + +str log_dir + +int log_interval + +List[str] metrics +int random_seed +int num_workers +Optional[int] prefetch_factor @@ -472,6 +475,10 @@ classDiagram class CheckpointCallback { +str save_dir +int interval + +bool weight_only + +Callable state_dict_fn + +Callable save_extra_fn + +Callable load_extra_fn +_save_checkpoint(context) +on_train_begin(context) +on_batch_end(context) @@ -483,6 +490,8 @@ classDiagram class ProgressBarCallback { +int num_epoch + +int log_interval + +IO file +on_epoch_begin(context) +on_batch_end(context) +on_epoch_end(context) @@ -491,6 +500,8 @@ classDiagram class MetricLoggerCallback { +str log_dir +int save_interval + +int log_interval + +List[str] metrics +on_batch_end(context) +on_train_end(context) +on_error(context) @@ -687,7 +698,7 @@ classDiagram } class SamplingPipeline { - +List strategies + +List[BaseSamplingStrategy] strategies +apply(logits, filter_value) Tensor +sample(logits, filter_value) Tensor } @@ -711,16 +722,16 @@ classDiagram class ChatCompletionRequest { +str model +List[ChatMessage] messages - +float temperature - +float top_p - +int top_k - +int max_tokens - +bool stream + +Optional[float] temperature + +Optional[float] top_p + +Optional[int] top_k + +Optional[int] max_tokens + +Optional[bool] stream +Optional[Union[str, List[str]]] stop +Optional[int] n +Optional[float] presence_penalty +Optional[float] frequency_penalty - +Optional[Dict] logit_bias + +Optional[Dict[int, float]] logit_bias +Optional[str] user } @@ -872,7 +883,6 @@ classDiagram InferenceScheduler *-- KVCache InferenceScheduler *-- Executor InferenceScheduler *-- TaskManager - SamplingPipeline *-- BaseSamplingStrategy AutoRegressiveLM *-- DecoderBlock AutoRegressiveLM *-- RotaryEmbedding AutoRegressiveLM *-- Embedding @@ -880,9 +890,10 @@ classDiagram EmbeddingEncoder *-- RotaryEmbedding EmbeddingEncoder *-- Embedding DecoderBlock *-- RMSNorm - BaseDataset o-- BaseStorage ChatCompletionRequest *-- ChatMessage MessagesRequest *-- AnthropicMessage + AutoTokenizer *-- ChatTemplate + BaseFactory *-- Registry %% --- Aggregation (weak ownership) --- AutoModel o-- BaseModelConfig @@ -890,9 +901,9 @@ classDiagram TrainContext o-- BaseStrategy TrainContext o-- BaseScheduler TrainContext o-- Checkpoint - AutoTokenizer o-- ChatTemplate KvcacheView o-- Storage - BaseFactory o-- Registry + SamplingPipeline o-- BaseSamplingStrategy + BaseDataset o-- BaseStorage %% --- Dependency (uses temporarily) --- TrainConfig ..> BaseStrategy : selects diff --git a/astrai/__init__.py b/astrai/__init__.py index ef408bf..a42f83c 100644 --- a/astrai/__init__.py +++ b/astrai/__init__.py @@ -1,4 +1,4 @@ -__version__ = "1.3.5" +__version__ = "1.3.6" __author__ = "ViperEkura" from astrai.config import ( diff --git a/astrai/config/__init__.py b/astrai/config/__init__.py index 6158147..e72b596 100644 --- a/astrai/config/__init__.py +++ b/astrai/config/__init__.py @@ -11,7 +11,6 @@ __all__ = [ "BaseModelConfig", "AutoRegressiveLMConfig", "EncoderConfig", - "ModelConfig", "ConfigFactory", "TrainConfig", ] diff --git a/astrai/config/base.py b/astrai/config/base.py index 0c6182c..1d34295 100644 --- a/astrai/config/base.py +++ b/astrai/config/base.py @@ -13,7 +13,7 @@ class BaseConfig: d[fld.name] = v elif v is None: d[fld.name] = None - elif isinstance(v, dict): + elif isinstance(v, (dict, list)): try: json.dumps(v) d[fld.name] = v diff --git a/astrai/inference/api/protocol.py b/astrai/inference/api/protocol.py index 9fc449a..da13cdf 100644 --- a/astrai/inference/api/protocol.py +++ b/astrai/inference/api/protocol.py @@ -226,6 +226,17 @@ class OpenAIHandler(ProtocolHandler): def create_response_id(self) -> str: return f"chatcmpl-{uuid.uuid4().hex[:12]}" + def get_stop_sequences(self) -> List[str]: + stop = self.request.stop + if stop is None: + return [] + return [stop] if isinstance(stop, str) else stop + + def on_token( + self, ctx: StreamContext, token: str, stop_checker: StopChecker + ) -> Optional[str]: + return stop_checker.check(ctx.accumulated) + def format_stream_start(self, ctx: StreamContext) -> List[str]: return [ _sse_event( diff --git a/tests/inference/test_server.py b/tests/inference/test_server.py index def7329..00584cc 100644 --- a/tests/inference/test_server.py +++ b/tests/inference/test_server.py @@ -157,5 +157,60 @@ def test_messages_with_system(client, loaded_model): assert data["type"] == "message" +def test_chat_completions_stop_sequence(client, loaded_model): + """POST /v1/chat/completions with stop parameter truncates at stop sequence.""" + + async def async_gen(): + yield "Hello" + yield "X" + yield "world" + + app.state.engine = loaded_model + loaded_model.generate_async.return_value = async_gen() + response = client.post( + "/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 100, + "stream": False, + "stop": ["X"], + }, + ) + assert response.status_code == 200 + data = response.json() + content = data["choices"][0]["message"]["content"] + assert "X" in content + assert "world" not in content + + +def test_chat_completions_stop_sequence_stream(client, loaded_model): + """POST /v1/chat/completions with stop parameter truncates SSE stream.""" + + async def async_gen(): + yield "Hello" + yield "X" + yield "world" + + app.state.engine = loaded_model + loaded_model.generate_async.return_value = async_gen() + response = client.post( + "/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 100, + "stream": True, + "stop": ["X"], + }, + headers={"Accept": "text/event-stream"}, + ) + assert response.status_code == 200 + content = response.content.decode("utf-8") + assert "Hello" in content + assert "world" not in content + assert any( + "finish_reason" in line for line in content.split("\n") if "stop" in line + ) + + if __name__ == "__main__": pytest.main([__file__, "-v"])