fix: ProgressBar默认输出到stdout
- file参数默认值改为None, 内部用 or sys.stdout 兜底 - 清理inference API中未使用的import (Optional, time, field) - 删除test_protocol中未使用的ctx变量
This commit is contained in:
parent
94d6e713e9
commit
dd1b39f435
|
|
@ -1,7 +1,7 @@
|
||||||
"""Anthropic message completion response builder."""
|
"""Anthropic message completion response builder."""
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, List, Tuple, Union
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
"""OpenAI chat completion response builder."""
|
"""OpenAI chat completion response builder."""
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,9 +5,8 @@ protocol-specific formatting to a ResponseBuilder.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import time
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
|
||||||
|
|
@ -210,7 +210,7 @@ class ProgressBarCallback(TrainCallback):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, num_epoch: int, log_interval: int = 100, file: IO[str] = sys.stdout
|
self, num_epoch: int, log_interval: int = 100, file: Optional[IO[str]] = None
|
||||||
):
|
):
|
||||||
self.num_epoch = num_epoch
|
self.num_epoch = num_epoch
|
||||||
self.log_interval = log_interval
|
self.log_interval = log_interval
|
||||||
|
|
@ -223,7 +223,7 @@ class ProgressBarCallback(TrainCallback):
|
||||||
context.dataloader,
|
context.dataloader,
|
||||||
desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
|
desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
|
||||||
dynamic_ncols=True,
|
dynamic_ncols=True,
|
||||||
file=self.file,
|
file=self.file or sys.stdout,
|
||||||
)
|
)
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
|
|
|
||||||
|
|
@ -121,7 +121,6 @@ class TestOpenAIResponseBuilder:
|
||||||
assert p["choices"][0]["finish_reason"] is None
|
assert p["choices"][0]["finish_reason"] is None
|
||||||
|
|
||||||
def test_format_chunk(self, builder):
|
def test_format_chunk(self, builder):
|
||||||
ctx = _make_ctx()
|
|
||||||
event = builder.format_chunk("hello")
|
event = builder.format_chunk("hello")
|
||||||
payload = json.loads(event.split("data: ", 1)[1])
|
payload = json.loads(event.split("data: ", 1)[1])
|
||||||
assert payload["choices"][0]["delta"]["content"] == "hello"
|
assert payload["choices"][0]["delta"]["content"] == "hello"
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue