diff --git a/astrai/inference/api/anthropic.py b/astrai/inference/api/anthropic.py index 74e6990..526554a 100644 --- a/astrai/inference/api/anthropic.py +++ b/astrai/inference/api/anthropic.py @@ -1,7 +1,7 @@ """Anthropic message completion response builder.""" import uuid -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Tuple, Union from pydantic import BaseModel diff --git a/astrai/inference/api/openai.py b/astrai/inference/api/openai.py index 25035ad..5e86437 100644 --- a/astrai/inference/api/openai.py +++ b/astrai/inference/api/openai.py @@ -1,7 +1,7 @@ """OpenAI chat completion response builder.""" import uuid -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Tuple from pydantic import BaseModel diff --git a/astrai/inference/api/protocol.py b/astrai/inference/api/protocol.py index 09822c2..45c51b5 100644 --- a/astrai/inference/api/protocol.py +++ b/astrai/inference/api/protocol.py @@ -5,9 +5,8 @@ protocol-specific formatting to a ResponseBuilder. """ import json -import time from abc import ABC, abstractmethod -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union from fastapi.responses import StreamingResponse diff --git a/astrai/trainer/train_callback.py b/astrai/trainer/train_callback.py index d061fbe..ee55a43 100644 --- a/astrai/trainer/train_callback.py +++ b/astrai/trainer/train_callback.py @@ -210,7 +210,7 @@ class ProgressBarCallback(TrainCallback): """ def __init__( - self, num_epoch: int, log_interval: int = 100, file: IO[str] = sys.stdout + self, num_epoch: int, log_interval: int = 100, file: Optional[IO[str]] = None ): self.num_epoch = num_epoch self.log_interval = log_interval @@ -223,7 +223,7 @@ class ProgressBarCallback(TrainCallback): context.dataloader, desc=f"Epoch {context.epoch + 1}/{self.num_epoch}", dynamic_ncols=True, - file=self.file, + file=self.file or sys.stdout, ) @only_on_rank(0) diff --git a/tests/inference/test_protocol.py b/tests/inference/test_protocol.py index bb39b3a..76049b2 100644 --- a/tests/inference/test_protocol.py +++ b/tests/inference/test_protocol.py @@ -121,7 +121,6 @@ class TestOpenAIResponseBuilder: assert p["choices"][0]["finish_reason"] is None def test_format_chunk(self, builder): - ctx = _make_ctx() event = builder.format_chunk("hello") payload = json.loads(event.split("data: ", 1)[1]) assert payload["choices"][0]["delta"]["content"] == "hello"