From dd1b39f4357a2117d66baf2eceb0328f12a01c1e Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Tue, 26 May 2026 13:27:05 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20ProgressBar=E9=BB=98=E8=AE=A4=E8=BE=93?= =?UTF-8?q?=E5=87=BA=E5=88=B0stdout?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - file参数默认值改为None, 内部用 or sys.stdout 兜底 - 清理inference API中未使用的import (Optional, time, field) - 删除test_protocol中未使用的ctx变量 --- astrai/inference/api/anthropic.py | 2 +- astrai/inference/api/openai.py | 2 +- astrai/inference/api/protocol.py | 3 +-- astrai/trainer/train_callback.py | 4 ++-- tests/inference/test_protocol.py | 1 - 5 files changed, 5 insertions(+), 7 deletions(-) diff --git a/astrai/inference/api/anthropic.py b/astrai/inference/api/anthropic.py index 74e6990..526554a 100644 --- a/astrai/inference/api/anthropic.py +++ b/astrai/inference/api/anthropic.py @@ -1,7 +1,7 @@ """Anthropic message completion response builder.""" import uuid -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Tuple, Union from pydantic import BaseModel diff --git a/astrai/inference/api/openai.py b/astrai/inference/api/openai.py index 25035ad..5e86437 100644 --- a/astrai/inference/api/openai.py +++ b/astrai/inference/api/openai.py @@ -1,7 +1,7 @@ """OpenAI chat completion response builder.""" import uuid -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Tuple from pydantic import BaseModel diff --git a/astrai/inference/api/protocol.py b/astrai/inference/api/protocol.py index 09822c2..45c51b5 100644 --- a/astrai/inference/api/protocol.py +++ b/astrai/inference/api/protocol.py @@ -5,9 +5,8 @@ protocol-specific formatting to a ResponseBuilder. """ import json -import time from abc import ABC, abstractmethod -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union from fastapi.responses import StreamingResponse diff --git a/astrai/trainer/train_callback.py b/astrai/trainer/train_callback.py index d061fbe..ee55a43 100644 --- a/astrai/trainer/train_callback.py +++ b/astrai/trainer/train_callback.py @@ -210,7 +210,7 @@ class ProgressBarCallback(TrainCallback): """ def __init__( - self, num_epoch: int, log_interval: int = 100, file: IO[str] = sys.stdout + self, num_epoch: int, log_interval: int = 100, file: Optional[IO[str]] = None ): self.num_epoch = num_epoch self.log_interval = log_interval @@ -223,7 +223,7 @@ class ProgressBarCallback(TrainCallback): context.dataloader, desc=f"Epoch {context.epoch + 1}/{self.num_epoch}", dynamic_ncols=True, - file=self.file, + file=self.file or sys.stdout, ) @only_on_rank(0) diff --git a/tests/inference/test_protocol.py b/tests/inference/test_protocol.py index bb39b3a..76049b2 100644 --- a/tests/inference/test_protocol.py +++ b/tests/inference/test_protocol.py @@ -121,7 +121,6 @@ class TestOpenAIResponseBuilder: assert p["choices"][0]["finish_reason"] is None def test_format_chunk(self, builder): - ctx = _make_ctx() event = builder.format_chunk("hello") payload = json.loads(event.split("data: ", 1)[1]) assert payload["choices"][0]["delta"]["content"] == "hello"