fix: ProgressBar默认输出到stdout

- file参数默认值改为None, 内部用 or sys.stdout 兜底
- 清理inference API中未使用的import (Optional, time, field)
- 删除test_protocol中未使用的ctx变量
This commit is contained in:
ViperEkura 2026-05-26 13:27:05 +08:00
parent 94d6e713e9
commit dd1b39f435
5 changed files with 5 additions and 7 deletions

View File

@ -1,7 +1,7 @@
"""Anthropic message completion response builder.""" """Anthropic message completion response builder."""
import uuid import uuid
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Tuple, Union
from pydantic import BaseModel from pydantic import BaseModel

View File

@ -1,7 +1,7 @@
"""OpenAI chat completion response builder.""" """OpenAI chat completion response builder."""
import uuid import uuid
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Tuple
from pydantic import BaseModel from pydantic import BaseModel

View File

@ -5,9 +5,8 @@ protocol-specific formatting to a ResponseBuilder.
""" """
import json import json
import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse

View File

@ -210,7 +210,7 @@ class ProgressBarCallback(TrainCallback):
""" """
def __init__( def __init__(
self, num_epoch: int, log_interval: int = 100, file: IO[str] = sys.stdout self, num_epoch: int, log_interval: int = 100, file: Optional[IO[str]] = None
): ):
self.num_epoch = num_epoch self.num_epoch = num_epoch
self.log_interval = log_interval self.log_interval = log_interval
@ -223,7 +223,7 @@ class ProgressBarCallback(TrainCallback):
context.dataloader, context.dataloader,
desc=f"Epoch {context.epoch + 1}/{self.num_epoch}", desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
dynamic_ncols=True, dynamic_ncols=True,
file=self.file, file=self.file or sys.stdout,
) )
@only_on_rank(0) @only_on_rank(0)

View File

@ -121,7 +121,6 @@ class TestOpenAIResponseBuilder:
assert p["choices"][0]["finish_reason"] is None assert p["choices"][0]["finish_reason"] is None
def test_format_chunk(self, builder): def test_format_chunk(self, builder):
ctx = _make_ctx()
event = builder.format_chunk("hello") event = builder.format_chunk("hello")
payload = json.loads(event.split("data: ", 1)[1]) payload = json.loads(event.split("data: ", 1)[1])
assert payload["choices"][0]["delta"]["content"] == "hello" assert payload["choices"][0]["delta"]["content"] == "hello"