fix: ProgressBar默认输出到stdout
- file参数默认值改为None, 内部用 or sys.stdout 兜底 - 清理inference API中未使用的import (Optional, time, field) - 删除test_protocol中未使用的ctx变量
This commit is contained in:
parent
94d6e713e9
commit
dd1b39f435
|
|
@ -1,7 +1,7 @@
|
|||
"""Anthropic message completion response builder."""
|
||||
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"""OpenAI chat completion response builder."""
|
||||
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
|
|||
|
|
@ -5,9 +5,8 @@ protocol-specific formatting to a ResponseBuilder.
|
|||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
|
|
|||
|
|
@ -210,7 +210,7 @@ class ProgressBarCallback(TrainCallback):
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self, num_epoch: int, log_interval: int = 100, file: IO[str] = sys.stdout
|
||||
self, num_epoch: int, log_interval: int = 100, file: Optional[IO[str]] = None
|
||||
):
|
||||
self.num_epoch = num_epoch
|
||||
self.log_interval = log_interval
|
||||
|
|
@ -223,7 +223,7 @@ class ProgressBarCallback(TrainCallback):
|
|||
context.dataloader,
|
||||
desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
|
||||
dynamic_ncols=True,
|
||||
file=self.file,
|
||||
file=self.file or sys.stdout,
|
||||
)
|
||||
|
||||
@only_on_rank(0)
|
||||
|
|
|
|||
|
|
@ -121,7 +121,6 @@ class TestOpenAIResponseBuilder:
|
|||
assert p["choices"][0]["finish_reason"] is None
|
||||
|
||||
def test_format_chunk(self, builder):
|
||||
ctx = _make_ctx()
|
||||
event = builder.format_chunk("hello")
|
||||
payload = json.loads(event.split("data: ", 1)[1])
|
||||
assert payload["choices"][0]["delta"]["content"] == "hello"
|
||||
|
|
|
|||
Loading…
Reference in New Issue