Compare commits
No commits in common. "785d65436c53de0fcbc28416bad6a7a4b9d8607f" and "6c8533f1d278191ea74749adda20ef461245f49d" have entirely different histories.
785d65436c
...
6c8533f1d2
|
|
@ -1,7 +1,7 @@
|
||||||
# AstrAI Dockerfile - Multi-stage Build (Optimized)
|
# AstrAI Dockerfile - Multi-stage Build (Optimized)
|
||||||
|
|
||||||
# Build stage - use base image with minimal build tools
|
# Build stage - use base image with minimal build tools
|
||||||
FROM ubuntu:24.04 AS builder
|
FROM nvidia/cuda:12.6.0-base-ubuntu24.04 AS builder
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
|
|
@ -18,7 +18,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
|
||||||
RUN python3.12 -m venv --copies /opt/venv
|
RUN python3.12 -m venv --copies /opt/venv
|
||||||
ENV PATH="/opt/venv/bin:$PATH"
|
ENV PATH="/opt/venv/bin:$PATH"
|
||||||
|
|
||||||
# Copy source code and install (deps read from pyproject.toml)
|
# Copy source code and install dependencies
|
||||||
COPY astrai/ ./astrai/
|
COPY astrai/ ./astrai/
|
||||||
COPY pyproject.toml .
|
COPY pyproject.toml .
|
||||||
RUN pip install --no-cache-dir --upgrade pip \
|
RUN pip install --no-cache-dir --upgrade pip \
|
||||||
|
|
@ -26,14 +26,13 @@ RUN pip install --no-cache-dir --upgrade pip \
|
||||||
--extra-index-url https://download.pytorch.org/whl/cu126
|
--extra-index-url https://download.pytorch.org/whl/cu126
|
||||||
|
|
||||||
# Production stage
|
# Production stage
|
||||||
FROM ubuntu:24.04 AS production
|
FROM nvidia/cuda:12.6.0-base-ubuntu24.04 AS production
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
# Install Python 3.12 runtime and healthcheck dependency
|
# Install Python 3.12 runtime
|
||||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
||||||
python3.12 \
|
python3.12 \
|
||||||
curl \
|
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# Copy virtual environment from builder
|
# Copy virtual environment from builder
|
||||||
|
|
|
||||||
|
|
@ -213,7 +213,7 @@ python scripts/demo/generate_batch.py
|
||||||
python scripts/demo/generate_ar.py
|
python scripts/demo/generate_ar.py
|
||||||
```
|
```
|
||||||
|
|
||||||
Watch a video walkthrough on [bilibili](https://www.bilibili.com/video/BV1fuLB6yEj6).
|
Watch a video walkthrough on [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd).
|
||||||
|
|
||||||
### Documentation
|
### Documentation
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -219,7 +219,7 @@ python scripts/demo/generate_batch.py
|
||||||
python scripts/demo/generate_ar.py
|
python scripts/demo/generate_ar.py
|
||||||
```
|
```
|
||||||
|
|
||||||
观看 [bilibili](https://www.bilibili.com/video/BV1fuLB6yEj6) 上的视频演示。
|
观看 [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd) 上的视频演示。
|
||||||
|
|
||||||
### 文档
|
### 文档
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -77,9 +77,6 @@ classDiagram
|
||||||
+int start_batch
|
+int start_batch
|
||||||
+str ckpt_dir
|
+str ckpt_dir
|
||||||
+int ckpt_interval
|
+int ckpt_interval
|
||||||
+str log_dir
|
|
||||||
+int log_interval
|
|
||||||
+List[str] metrics
|
|
||||||
+int random_seed
|
+int random_seed
|
||||||
+int num_workers
|
+int num_workers
|
||||||
+Optional[int] prefetch_factor
|
+Optional[int] prefetch_factor
|
||||||
|
|
@ -475,10 +472,6 @@ classDiagram
|
||||||
class CheckpointCallback {
|
class CheckpointCallback {
|
||||||
+str save_dir
|
+str save_dir
|
||||||
+int interval
|
+int interval
|
||||||
+bool weight_only
|
|
||||||
+Callable state_dict_fn
|
|
||||||
+Callable save_extra_fn
|
|
||||||
+Callable load_extra_fn
|
|
||||||
+_save_checkpoint(context)
|
+_save_checkpoint(context)
|
||||||
+on_train_begin(context)
|
+on_train_begin(context)
|
||||||
+on_batch_end(context)
|
+on_batch_end(context)
|
||||||
|
|
@ -490,8 +483,6 @@ classDiagram
|
||||||
|
|
||||||
class ProgressBarCallback {
|
class ProgressBarCallback {
|
||||||
+int num_epoch
|
+int num_epoch
|
||||||
+int log_interval
|
|
||||||
+IO file
|
|
||||||
+on_epoch_begin(context)
|
+on_epoch_begin(context)
|
||||||
+on_batch_end(context)
|
+on_batch_end(context)
|
||||||
+on_epoch_end(context)
|
+on_epoch_end(context)
|
||||||
|
|
@ -500,8 +491,6 @@ classDiagram
|
||||||
class MetricLoggerCallback {
|
class MetricLoggerCallback {
|
||||||
+str log_dir
|
+str log_dir
|
||||||
+int save_interval
|
+int save_interval
|
||||||
+int log_interval
|
|
||||||
+List[str] metrics
|
|
||||||
+on_batch_end(context)
|
+on_batch_end(context)
|
||||||
+on_train_end(context)
|
+on_train_end(context)
|
||||||
+on_error(context)
|
+on_error(context)
|
||||||
|
|
@ -698,7 +687,7 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class SamplingPipeline {
|
class SamplingPipeline {
|
||||||
+List[BaseSamplingStrategy] strategies
|
+List strategies
|
||||||
+apply(logits, filter_value) Tensor
|
+apply(logits, filter_value) Tensor
|
||||||
+sample(logits, filter_value) Tensor
|
+sample(logits, filter_value) Tensor
|
||||||
}
|
}
|
||||||
|
|
@ -722,16 +711,16 @@ classDiagram
|
||||||
class ChatCompletionRequest {
|
class ChatCompletionRequest {
|
||||||
+str model
|
+str model
|
||||||
+List[ChatMessage] messages
|
+List[ChatMessage] messages
|
||||||
+Optional[float] temperature
|
+float temperature
|
||||||
+Optional[float] top_p
|
+float top_p
|
||||||
+Optional[int] top_k
|
+int top_k
|
||||||
+Optional[int] max_tokens
|
+int max_tokens
|
||||||
+Optional[bool] stream
|
+bool stream
|
||||||
+Optional[Union[str, List[str]]] stop
|
+Optional[Union[str, List[str]]] stop
|
||||||
+Optional[int] n
|
+Optional[int] n
|
||||||
+Optional[float] presence_penalty
|
+Optional[float] presence_penalty
|
||||||
+Optional[float] frequency_penalty
|
+Optional[float] frequency_penalty
|
||||||
+Optional[Dict[int, float]] logit_bias
|
+Optional[Dict] logit_bias
|
||||||
+Optional[str] user
|
+Optional[str] user
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -883,6 +872,7 @@ classDiagram
|
||||||
InferenceScheduler *-- KVCache
|
InferenceScheduler *-- KVCache
|
||||||
InferenceScheduler *-- Executor
|
InferenceScheduler *-- Executor
|
||||||
InferenceScheduler *-- TaskManager
|
InferenceScheduler *-- TaskManager
|
||||||
|
SamplingPipeline *-- BaseSamplingStrategy
|
||||||
AutoRegressiveLM *-- DecoderBlock
|
AutoRegressiveLM *-- DecoderBlock
|
||||||
AutoRegressiveLM *-- RotaryEmbedding
|
AutoRegressiveLM *-- RotaryEmbedding
|
||||||
AutoRegressiveLM *-- Embedding
|
AutoRegressiveLM *-- Embedding
|
||||||
|
|
@ -890,10 +880,9 @@ classDiagram
|
||||||
EmbeddingEncoder *-- RotaryEmbedding
|
EmbeddingEncoder *-- RotaryEmbedding
|
||||||
EmbeddingEncoder *-- Embedding
|
EmbeddingEncoder *-- Embedding
|
||||||
DecoderBlock *-- RMSNorm
|
DecoderBlock *-- RMSNorm
|
||||||
|
BaseDataset o-- BaseStorage
|
||||||
ChatCompletionRequest *-- ChatMessage
|
ChatCompletionRequest *-- ChatMessage
|
||||||
MessagesRequest *-- AnthropicMessage
|
MessagesRequest *-- AnthropicMessage
|
||||||
AutoTokenizer *-- ChatTemplate
|
|
||||||
BaseFactory *-- Registry
|
|
||||||
|
|
||||||
%% --- Aggregation (weak ownership) ---
|
%% --- Aggregation (weak ownership) ---
|
||||||
AutoModel o-- BaseModelConfig
|
AutoModel o-- BaseModelConfig
|
||||||
|
|
@ -901,9 +890,9 @@ classDiagram
|
||||||
TrainContext o-- BaseStrategy
|
TrainContext o-- BaseStrategy
|
||||||
TrainContext o-- BaseScheduler
|
TrainContext o-- BaseScheduler
|
||||||
TrainContext o-- Checkpoint
|
TrainContext o-- Checkpoint
|
||||||
|
AutoTokenizer o-- ChatTemplate
|
||||||
KvcacheView o-- Storage
|
KvcacheView o-- Storage
|
||||||
SamplingPipeline o-- BaseSamplingStrategy
|
BaseFactory o-- Registry
|
||||||
BaseDataset o-- BaseStorage
|
|
||||||
|
|
||||||
%% --- Dependency (uses temporarily) ---
|
%% --- Dependency (uses temporarily) ---
|
||||||
TrainConfig ..> BaseStrategy : selects
|
TrainConfig ..> BaseStrategy : selects
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
__version__ = "1.3.6"
|
__version__ = "1.3.5"
|
||||||
__author__ = "ViperEkura"
|
__author__ = "ViperEkura"
|
||||||
|
|
||||||
from astrai.config import (
|
from astrai.config import (
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ __all__ = [
|
||||||
"BaseModelConfig",
|
"BaseModelConfig",
|
||||||
"AutoRegressiveLMConfig",
|
"AutoRegressiveLMConfig",
|
||||||
"EncoderConfig",
|
"EncoderConfig",
|
||||||
|
"ModelConfig",
|
||||||
"ConfigFactory",
|
"ConfigFactory",
|
||||||
"TrainConfig",
|
"TrainConfig",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ class BaseConfig:
|
||||||
d[fld.name] = v
|
d[fld.name] = v
|
||||||
elif v is None:
|
elif v is None:
|
||||||
d[fld.name] = None
|
d[fld.name] = None
|
||||||
elif isinstance(v, (dict, list)):
|
elif isinstance(v, dict):
|
||||||
try:
|
try:
|
||||||
json.dumps(v)
|
json.dumps(v)
|
||||||
d[fld.name] = v
|
d[fld.name] = v
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from dataclasses import dataclass, field, fields
|
from dataclasses import dataclass, field, fields
|
||||||
from typing import Callable, List, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
|
@ -56,19 +56,6 @@ class TrainConfig(BaseConfig):
|
||||||
default=5000, metadata={"help": "Number of iterations between checkpoints."}
|
default=5000, metadata={"help": "Number of iterations between checkpoints."}
|
||||||
)
|
)
|
||||||
|
|
||||||
# metric setting
|
|
||||||
log_dir: str = field(
|
|
||||||
default="./checkpoint/logs", metadata={"help": "Directory for metric logs."}
|
|
||||||
)
|
|
||||||
log_interval: int = field(
|
|
||||||
default=100,
|
|
||||||
metadata={"help": "Number of batch iterations between metric logs."},
|
|
||||||
)
|
|
||||||
metrics: List[str] = field(
|
|
||||||
default_factory=lambda: ["loss", "lr"],
|
|
||||||
metadata={"help": "Metrics to record during training."},
|
|
||||||
)
|
|
||||||
|
|
||||||
# dataloader setting
|
# dataloader setting
|
||||||
random_seed: int = field(default=3407, metadata={"help": "Random seed."})
|
random_seed: int = field(default=3407, metadata={"help": "Random seed."})
|
||||||
num_workers: int = field(
|
num_workers: int = field(
|
||||||
|
|
|
||||||
|
|
@ -226,17 +226,6 @@ class OpenAIHandler(ProtocolHandler):
|
||||||
def create_response_id(self) -> str:
|
def create_response_id(self) -> str:
|
||||||
return f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
return f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
||||||
|
|
||||||
def get_stop_sequences(self) -> List[str]:
|
|
||||||
stop = self.request.stop
|
|
||||||
if stop is None:
|
|
||||||
return []
|
|
||||||
return [stop] if isinstance(stop, str) else stop
|
|
||||||
|
|
||||||
def on_token(
|
|
||||||
self, ctx: StreamContext, token: str, stop_checker: StopChecker
|
|
||||||
) -> Optional[str]:
|
|
||||||
return stop_checker.check(ctx.accumulated)
|
|
||||||
|
|
||||||
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
||||||
return [
|
return [
|
||||||
_sse_event(
|
_sse_event(
|
||||||
|
|
|
||||||
|
|
@ -163,4 +163,5 @@ def spawn_parallel_fn(
|
||||||
nprocs=world_size,
|
nprocs=world_size,
|
||||||
start_method=start_method,
|
start_method=start_method,
|
||||||
join=True,
|
join=True,
|
||||||
|
daemon=True,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,7 @@ class Checkpoint:
|
||||||
meta = {
|
meta = {
|
||||||
"epoch": self.epoch,
|
"epoch": self.epoch,
|
||||||
"iteration": self.iteration,
|
"iteration": self.iteration,
|
||||||
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
|
"timestamp": time.time(),
|
||||||
}
|
}
|
||||||
meta.update(self.meta)
|
meta.update(self.meta)
|
||||||
with open(save_path / "meta.json", "w") as f:
|
with open(save_path / "meta.json", "w") as f:
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,9 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import IO, Callable, List, Optional, Protocol, runtime_checkable
|
from typing import Callable, List, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
@ -212,12 +211,8 @@ class ProgressBarCallback(TrainCallback):
|
||||||
Progress bar callback for trainer.
|
Progress bar callback for trainer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, num_epoch: int):
|
||||||
self, num_epoch: int, log_interval: int = 100, file: IO[str] = sys.stdout
|
|
||||||
):
|
|
||||||
self.num_epoch = num_epoch
|
self.num_epoch = num_epoch
|
||||||
self.log_interval = log_interval
|
|
||||||
self.file = file
|
|
||||||
self.progress_bar: tqdm = None
|
self.progress_bar: tqdm = None
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
|
|
@ -226,7 +221,6 @@ class ProgressBarCallback(TrainCallback):
|
||||||
context.dataloader,
|
context.dataloader,
|
||||||
desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
|
desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
|
||||||
dynamic_ncols=True,
|
dynamic_ncols=True,
|
||||||
file=self.file,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
|
|
@ -280,7 +274,7 @@ class MetricLoggerCallback(TrainCallback):
|
||||||
|
|
||||||
def _get_log_data(self, context: TrainContext):
|
def _get_log_data(self, context: TrainContext):
|
||||||
return {
|
return {
|
||||||
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
|
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||||
"epoch": context.epoch,
|
"epoch": context.epoch,
|
||||||
"iter": context.iteration,
|
"iter": context.iteration,
|
||||||
**{m: self._metric_funcs[m](context) for m in self.metrics},
|
**{m: self._metric_funcs[m](context) for m in self.metrics},
|
||||||
|
|
|
||||||
|
|
@ -36,14 +36,8 @@ class Trainer:
|
||||||
cfg.ckpt_interval,
|
cfg.ckpt_interval,
|
||||||
state_dict_fn=cfg.state_dict_fn,
|
state_dict_fn=cfg.state_dict_fn,
|
||||||
),
|
),
|
||||||
CallbackFactory.create(
|
|
||||||
"metric_logger",
|
|
||||||
log_dir=cfg.log_dir,
|
|
||||||
save_interval=cfg.ckpt_interval,
|
|
||||||
log_interval=cfg.log_interval,
|
|
||||||
metrics=cfg.metrics,
|
|
||||||
),
|
|
||||||
CallbackFactory.create("progress_bar", cfg.n_epoch),
|
CallbackFactory.create("progress_bar", cfg.n_epoch),
|
||||||
|
CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval),
|
||||||
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
||||||
CallbackFactory.create("validation"),
|
CallbackFactory.create("validation"),
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,12 @@
|
||||||
services:
|
services:
|
||||||
server:
|
server:
|
||||||
build:
|
build: .
|
||||||
context: .
|
image: astrai:latest
|
||||||
dockerfile: Dockerfile
|
|
||||||
user: "${UID:-1000}:${GID:-1000}"
|
|
||||||
ports:
|
ports:
|
||||||
- "8000:8000"
|
- "8000:8000"
|
||||||
volumes:
|
volumes:
|
||||||
- ./params:/app/params:ro
|
- ./params:/app/params:ro
|
||||||
|
- ./checkpoints:/app/checkpoints
|
||||||
command: python -m scripts.tools.server --port 8000 --device cuda
|
command: python -m scripts.tools.server --port 8000 --device cuda
|
||||||
deploy:
|
deploy:
|
||||||
resources:
|
resources:
|
||||||
|
|
@ -26,14 +25,13 @@ services:
|
||||||
|
|
||||||
server-cpu:
|
server-cpu:
|
||||||
profiles: [cpu]
|
profiles: [cpu]
|
||||||
build:
|
build: .
|
||||||
context: .
|
image: astrai:latest
|
||||||
dockerfile: Dockerfile
|
|
||||||
user: "${UID:-1000}:${GID:-1000}"
|
|
||||||
ports:
|
ports:
|
||||||
- "8000:8000"
|
- "8000:8000"
|
||||||
volumes:
|
volumes:
|
||||||
- ./params:/app/params:ro
|
- ./params:/app/params:ro
|
||||||
|
- ./checkpoints:/app/checkpoints
|
||||||
command: python -m scripts.tools.server --port 8000 --device cpu
|
command: python -m scripts.tools.server --port 8000 --device cpu
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,6 @@ NC='\033[0m' # No Color
|
||||||
IMAGE_NAME="astrai"
|
IMAGE_NAME="astrai"
|
||||||
IMAGE_TAG="latest"
|
IMAGE_TAG="latest"
|
||||||
REGISTRY=""
|
REGISTRY=""
|
||||||
CONTAINER_ID=""
|
|
||||||
|
|
||||||
# Print colored messages
|
# Print colored messages
|
||||||
print_info() {
|
print_info() {
|
||||||
|
|
@ -176,10 +175,6 @@ main() {
|
||||||
PORT="$2"
|
PORT="$2"
|
||||||
shift 2
|
shift 2
|
||||||
;;
|
;;
|
||||||
--container)
|
|
||||||
CONTAINER_ID="$2"
|
|
||||||
shift 2
|
|
||||||
;;
|
|
||||||
--gpu)
|
--gpu)
|
||||||
GPU=true
|
GPU=true
|
||||||
shift
|
shift
|
||||||
|
|
@ -202,7 +197,6 @@ main() {
|
||||||
echo " --dockerfile FILE Dockerfile path (default: Dockerfile)"
|
echo " --dockerfile FILE Dockerfile path (default: Dockerfile)"
|
||||||
echo " --context PATH Build context (default: .)"
|
echo " --context PATH Build context (default: .)"
|
||||||
echo " --port PORT Port for run (default: 8000)"
|
echo " --port PORT Port for run (default: 8000)"
|
||||||
echo " --container ID Container ID for logs"
|
|
||||||
echo " --gpu Enable GPU support"
|
echo " --gpu Enable GPU support"
|
||||||
echo " --help Show this help message"
|
echo " --help Show this help message"
|
||||||
echo ""
|
echo ""
|
||||||
|
|
@ -211,7 +205,6 @@ main() {
|
||||||
echo " $0 build --tag v1.0.0"
|
echo " $0 build --tag v1.0.0"
|
||||||
echo " $0 run --port 8080"
|
echo " $0 run --port 8080"
|
||||||
echo " $0 run --gpu"
|
echo " $0 run --gpu"
|
||||||
echo " $0 logs --container abc123"
|
|
||||||
echo " $0 push --registry ghcr.io/username"
|
echo " $0 push --registry ghcr.io/username"
|
||||||
exit 0
|
exit 0
|
||||||
;;
|
;;
|
||||||
|
|
@ -244,7 +237,7 @@ main() {
|
||||||
show_info
|
show_info
|
||||||
;;
|
;;
|
||||||
logs)
|
logs)
|
||||||
show_logs "$CONTAINER_ID"
|
show_logs "$2"
|
||||||
;;
|
;;
|
||||||
"")
|
"")
|
||||||
print_error "No command specified. Use --help for usage"
|
print_error "No command specified. Use --help for usage"
|
||||||
|
|
|
||||||
|
|
@ -157,60 +157,5 @@ def test_messages_with_system(client, loaded_model):
|
||||||
assert data["type"] == "message"
|
assert data["type"] == "message"
|
||||||
|
|
||||||
|
|
||||||
def test_chat_completions_stop_sequence(client, loaded_model):
|
|
||||||
"""POST /v1/chat/completions with stop parameter truncates at stop sequence."""
|
|
||||||
|
|
||||||
async def async_gen():
|
|
||||||
yield "Hello"
|
|
||||||
yield "X"
|
|
||||||
yield "world"
|
|
||||||
|
|
||||||
app.state.engine = loaded_model
|
|
||||||
loaded_model.generate_async.return_value = async_gen()
|
|
||||||
response = client.post(
|
|
||||||
"/v1/chat/completions",
|
|
||||||
json={
|
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
|
||||||
"max_tokens": 100,
|
|
||||||
"stream": False,
|
|
||||||
"stop": ["X"],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
content = data["choices"][0]["message"]["content"]
|
|
||||||
assert "X" in content
|
|
||||||
assert "world" not in content
|
|
||||||
|
|
||||||
|
|
||||||
def test_chat_completions_stop_sequence_stream(client, loaded_model):
|
|
||||||
"""POST /v1/chat/completions with stop parameter truncates SSE stream."""
|
|
||||||
|
|
||||||
async def async_gen():
|
|
||||||
yield "Hello"
|
|
||||||
yield "X"
|
|
||||||
yield "world"
|
|
||||||
|
|
||||||
app.state.engine = loaded_model
|
|
||||||
loaded_model.generate_async.return_value = async_gen()
|
|
||||||
response = client.post(
|
|
||||||
"/v1/chat/completions",
|
|
||||||
json={
|
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
|
||||||
"max_tokens": 100,
|
|
||||||
"stream": True,
|
|
||||||
"stop": ["X"],
|
|
||||||
},
|
|
||||||
headers={"Accept": "text/event-stream"},
|
|
||||||
)
|
|
||||||
assert response.status_code == 200
|
|
||||||
content = response.content.decode("utf-8")
|
|
||||||
assert "Hello" in content
|
|
||||||
assert "world" not in content
|
|
||||||
assert any(
|
|
||||||
"finish_reason" in line for line in content.split("\n") if "stop" in line
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pytest.main([__file__, "-v"])
|
pytest.main([__file__, "-v"])
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue