diff --git a/README.md b/README.md index b049505..d72d743 100644 --- a/README.md +++ b/README.md @@ -47,6 +47,7 @@ - 📦 **Lightweight**: Minimal dependencies, easy to deploy. - 🔬 **Research‑Friendly**: Modular design, easy to experiment with new ideas. - 🤗 **HuggingFace Integration**: Compatible with HuggingFace models and datasets. +- 🔌 **Dual API Compatibility**: Supports both OpenAI and Anthropic chat completion APIs out of the box. ### Quick Start @@ -67,27 +68,9 @@ pip install -e ".[dev]" #### Train a Model ```bash -python scripts/tools/train.py \ - --train_type=seq \ - --data_root_path=/path/to/dataset \ - --param_path=/path/to/model \ - --n_epoch=3 \ - --batch_size=4 \ - --accumulation_steps=8 \ - --max_lr=3e-4 \ - --warmup_steps=2000 \ - --ckpt_interval=5000 \ - --ckpt_dir=./checkpoints +python scripts/tools/train.py --train_type=seq --data_root_path=/path/to/dataset --param_path=/path/to/model ``` -#### Generate Text - -```bash -python scripts/tools/generate.py --param_path=/path/to/param_path -``` - -#### Training Parameters - | Parameter | Description | Default | |-----------|-------------|---------| | `--train_type` | Training type (`seq`, `sft`, `dpo`) | required | @@ -105,6 +88,12 @@ python scripts/tools/generate.py --param_path=/path/to/param_path Full reference at [Parameter Guide](./assets/docs/params.md#training-parameters). +#### Generate Text + +```bash +python scripts/tools/generate.py --param_path=/path/to/param_path +``` + #### Docker Build and run with Docker (recommended for GPU environments): @@ -131,7 +120,7 @@ docker run --gpus all -v /path/to/data:/data -it astrai:latest #### Start HTTP Server -Start the inference server with OpenAI-compatible HTTP API: +Start the inference server with OpenAI and Anthropic-compatible HTTP API: ```bash python -m scripts.tools.server --port 8000 --device cuda @@ -140,7 +129,7 @@ python -m scripts.tools.server --port 8000 --device cuda Make requests: ```bash -# Chat API (OpenAI compatible) +# OpenAI-compatible curl -X POST http://localhost:8000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ @@ -148,7 +137,7 @@ curl -X POST http://localhost:8000/v1/chat/completions \ "max_tokens": 512 }' -# Streaming response +# OpenAI-compatible streaming curl -X POST http://localhost:8000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ @@ -157,6 +146,27 @@ curl -X POST http://localhost:8000/v1/chat/completions \ "max_tokens": 500 }' +# Anthropic-compatible +curl -X POST http://localhost:8000/v1/messages \ + -H "Content-Type: application/json" \ + -d '{ + "model": "astrai", + "system": "You are a helpful assistant.", + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 512 + }' + +# Anthropic-compatible streaming with stop sequences +curl -X POST http://localhost:8000/v1/messages \ + -H "Content-Type: application/json" \ + -d '{ + "model": "astrai", + "messages": [{"role": "user", "content": "Write a story"}], + "max_tokens": 500, + "stream": true, + "stop_sequences": ["The end"] + }' + # Health check curl http://localhost:8000/health ``` diff --git a/assets/docs/README-zh-CN.md b/assets/docs/README-zh-CN.md index ff4dc99..92f9e96 100644 --- a/assets/docs/README-zh-CN.md +++ b/assets/docs/README-zh-CN.md @@ -53,6 +53,7 @@ - 📦 **轻量**: 依赖少,部署简单。 - 🔬 **研究友好**: 模块化设计,便于实验新想法。 - 🤗 **HuggingFace 集成**: 兼容 HuggingFace 模型与数据集。 +- 🔌 **双 API 兼容**: 同时支持 OpenAI 和 Anthropic 聊天补全 API,开箱即用。 ### 快速开始 @@ -73,27 +74,9 @@ pip install -e ".[dev]" #### 训练模型 ```bash -python scripts/tools/train.py \ - --train_type=seq \ - --data_root_path=/path/to/dataset \ - --param_path=/path/to/model \ - --n_epoch=3 \ - --batch_size=4 \ - --accumulation_steps=8 \ - --max_lr=3e-4 \ - --warmup_steps=2000 \ - --ckpt_interval=5000 \ - --ckpt_dir=./checkpoints +python scripts/tools/train.py --train_type=seq --data_root_path=/path/to/dataset --param_path=/path/to/model ``` -#### 文本生成 - -```bash -python scripts/tools/generate.py --param_path=/path/to/param_path -``` - -#### 训练参数 - | 参数 | 说明 | 默认值 | |------|------|--------| | `--train_type` | 训练类型(`seq`, `sft`, `dpo`) | 必填 | @@ -111,6 +94,12 @@ python scripts/tools/generate.py --param_path=/path/to/param_path 完整参数列表见[参数说明](./params.md#training-parameters)。 +#### 文本生成 + +```bash +python scripts/tools/generate.py --param_path=/path/to/param_path +``` + #### Docker 使用 Docker 构建和运行(推荐用于 GPU 环境): @@ -137,7 +126,7 @@ docker run --gpus all -v /path/to/data:/data -it astrai:latest #### 启动 HTTP 服务 -启动推理服务器,支持 OpenAI 兼容的 HTTP API: +启动推理服务器,支持 OpenAI 和 Anthropic 兼容的 HTTP API: ```bash python -m scripts.tools.server --port 8000 --device cuda @@ -146,7 +135,7 @@ python -m scripts.tools.server --port 8000 --device cuda 发起请求: ```bash -# Chat API(OpenAI 兼容) +# OpenAI 兼容 curl -X POST http://localhost:8000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ @@ -154,7 +143,7 @@ curl -X POST http://localhost:8000/v1/chat/completions \ "max_tokens": 512 }' -# 流式响应 +# OpenAI 兼容流式 curl -X POST http://localhost:8000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ @@ -163,6 +152,27 @@ curl -X POST http://localhost:8000/v1/chat/completions \ "max_tokens": 500 }' +# Anthropic 兼容 +curl -X POST http://localhost:8000/v1/messages \ + -H "Content-Type: application/json" \ + -d '{ + "model": "astrai", + "system": "你是一个乐于助人的助手。", + "messages": [{"role": "user", "content": "你好"}], + "max_tokens": 512 + }' + +# Anthropic 兼容流式并设置停止序列 +curl -X POST http://localhost:8000/v1/messages \ + -H "Content-Type: application/json" \ + -d '{ + "model": "astrai", + "messages": [{"role": "user", "content": "写个故事"}], + "max_tokens": 500, + "stream": true, + "stop_sequences": ["结束"] + }' + # 健康检查 curl http://localhost:8000/health ``` diff --git a/assets/docs/introduction.md b/assets/docs/introduction.md index b7420a3..c350574 100644 --- a/assets/docs/introduction.md +++ b/assets/docs/introduction.md @@ -262,25 +262,60 @@ curl -X POST http://localhost:8000/v1/chat/completions \ The server uses Server-Sent Events (SSE) with content type `text/event-stream`. -### Simple Generation Endpoint +### Health Check -For basic text generation without chat format: + +### Anthropic-Compatible Endpoint + +The server also provides an Anthropic-compatible endpoint at `/v1/messages`: ```bash -curl -X POST "http://localhost:8000/generate?query=Hello&max_len=1000" \ - -H "Content-Type: application/json" -``` - -Or with conversation history: - -```bash -curl -X POST "http://localhost:8000/generate" \ +curl -X POST http://localhost:8000/v1/messages \ -H "Content-Type: application/json" \ -d '{ - "query": "What is AI?", - "history": [["Hello", "Hi there!"], ["How are you?", "I'm doing well"]], - "temperature": 0.8, - "max_len": 2048 + "model": "astrai", + "system": "You are a helpful assistant.", + "messages": [{"role": "user", "content": "Hello, how are you?"}], + "max_tokens": 2048 + }' +``` + +Response: +```json +{ + "id": "msg_abc123...", + "type": "message", + "role": "assistant", + "model": "astrai", + "content": [{"type": "text", "text": "Hello! I am doing well..."}], + "stop_reason": "end_turn", + "stop_sequence": null, + "usage": {"input_tokens": 20, "output_tokens": 15} +} +``` + +Streaming: +```bash +curl -X POST http://localhost:8000/v1/messages \ + -H "Content-Type: application/json" \ + -d '{ + "model": "astrai", + "system": "You are a helpful assistant.", + "messages": [{"role": "user", "content": "Write a short poem"}], + "max_tokens": 500, + "stream": true + }' +``` + +Supports `stop_sequences` for early termination: +```bash +curl -X POST http://localhost:8000/v1/messages \ + -H "Content-Type: application/json" \ + -d '{ + "model": "astrai", + "messages": [{"role": "user", "content": "Write a story"}], + "max_tokens": 500, + "stop_sequences": ["The end", "THE END"] }' ``` diff --git a/astrai/inference/server.py b/astrai/inference/server.py index 8cc3cd7..e1e7d37 100644 --- a/astrai/inference/server.py +++ b/astrai/inference/server.py @@ -1,5 +1,5 @@ """ -OpenAI-compatible chat completion server backed by continuous-batching inference. +OpenAI / Anthropic-compatible chat completion server backed by continuous-batching inference. """ import json @@ -61,6 +61,25 @@ class ChatCompletionRequest(BaseModel): user: Optional[str] = None +class AnthropicMessage(BaseModel): + role: str + content: Union[str, List[Dict[str, Any]]] + + +class MessagesRequest(BaseModel): + """Anthropic Messages API request body.""" + + model: str = "astrai" + max_tokens: int = Field(default=1024, ge=1) + messages: List[AnthropicMessage] + system: Optional[str] = None + temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0) + top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0) + top_k: Optional[int] = Field(default=50, ge=1) + stream: Optional[bool] = False + stop_sequences: Optional[List[str]] = None + + def configure_server( device: str = "cuda", dtype: torch.dtype = torch.bfloat16, @@ -264,55 +283,183 @@ async def chat_completion(request: ChatCompletionRequest): } -@app.post("/generate") -async def generate( - query: str, - history: Optional[List[List[str]]] = None, - temperature: float = 0.8, - top_p: float = 0.95, - top_k: int = 50, - max_len: int = 2048, - stream: bool = False, -): - """Legacy non-OpenAI generation endpoint (kept for backward compat).""" +def _make_anthropic_sse(event: str, data: Dict[str, Any]) -> str: + return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" + + +def _check_stop_sequence(text: str, stop_sequences: List[str]) -> Optional[str]: + for seq in stop_sequences: + if seq and seq in text: + return seq + return None + + +def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str: + if isinstance(content, str): + return content + if isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + return block.get("text", "") + return "" + + +def _build_anthropic_messages( + messages: List[AnthropicMessage], system: Optional[str] +) -> List[Dict[str, str]]: + result: List[Dict[str, str]] = [] + if system: + result.append({"role": "system", "content": system}) + for m in messages: + content = _extract_text_content(m.content) + if content: + result.append({"role": m.role, "content": content}) + return result + + +@app.post("/v1/messages") +async def create_message(request: MessagesRequest): + """Anthropic-compatible Messages API endpoint (streaming + non-streaming).""" engine = _get_engine() + resp_id = f"msg_{uuid.uuid4().hex[:24]}" + model = request.model - messages = [] - if history: - for h in history: - if len(h) >= 2: - messages.append({"role": "user", "content": h[0]}) - messages.append({"role": "assistant", "content": h[1]}) - messages.append({"role": "user", "content": query}) + chat_messages = _build_anthropic_messages(request.messages, request.system) + prompt = engine.tokenizer.apply_chat_template(chat_messages, tokenize=False) + prompt_tokens = len(engine.tokenizer.encode(prompt)) - prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False) + stop_sequences = request.stop_sequences or [] - if stream: + if request.stream: agen = engine.generate_async( prompt=prompt, - max_tokens=max_len, - temperature=temperature, - top_p=top_p, - top_k=top_k, + max_tokens=request.max_tokens, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, ) - async def text_stream(): - async for token in agen: - yield token + "\n" + async def event_stream(): + yield _make_anthropic_sse( + "message_start", + { + "type": "message_start", + "message": { + "id": resp_id, + "type": "message", + "role": "assistant", + "model": model, + "content": [], + "usage": {"input_tokens": prompt_tokens}, + }, + }, + ) - return StreamingResponse(text_stream(), media_type="text/plain") - else: - chunks = [] - for token in engine.generate( - prompt=prompt, - stream=True, - max_tokens=max_len, - temperature=temperature, - top_p=top_p, - top_k=top_k, - ): - chunks.append(token) - return {"response": "".join(chunks)} + yield _make_anthropic_sse( + "content_block_start", + { + "type": "content_block_start", + "index": 0, + "content_block": {"type": "text", "text": ""}, + }, + ) + + completion_tokens = 0 + accumulated = "" + stopped_seq: Optional[str] = None + async for token in agen: + accumulated += token + completion_tokens += 1 + + matched = _check_stop_sequence(accumulated, stop_sequences) + if matched: + text = accumulated[: accumulated.rfind(matched)] + stopped_seq = matched + if text: + yield _make_anthropic_sse( + "content_block_delta", + { + "type": "content_block_delta", + "index": 0, + "delta": {"type": "text_delta", "text": text}, + }, + ) + break + + yield _make_anthropic_sse( + "content_block_delta", + { + "type": "content_block_delta", + "index": 0, + "delta": {"type": "text_delta", "text": token}, + }, + ) + + yield _make_anthropic_sse( + "content_block_stop", + {"type": "content_block_stop", "index": 0}, + ) + + stop_reason = "stop_sequence" if stopped_seq else "end_turn" + yield _make_anthropic_sse( + "message_delta", + { + "type": "message_delta", + "delta": {"stop_reason": stop_reason, "stop_sequence": stopped_seq}, + "usage": {"output_tokens": completion_tokens}, + }, + ) + + yield _make_anthropic_sse( + "message_stop", + {"type": "message_stop"}, + ) + + return StreamingResponse( + event_stream(), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, + ) + + completion_tokens = 0 + chunks: List[str] = [] + agen = engine.generate_async( + prompt=prompt, + max_tokens=request.max_tokens, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + ) + stopped_seq: Optional[str] = None + accumulated = "" + async for token in agen: + chunks.append(token) + completion_tokens += 1 + accumulated += token + matched = _check_stop_sequence(accumulated, stop_sequences) + if matched: + stopped_seq = matched + break + + content = "".join(chunks) + if stopped_seq: + idx = content.rfind(stopped_seq) + if idx != -1: + content = content[:idx] + + return { + "id": resp_id, + "type": "message", + "role": "assistant", + "model": model, + "content": [{"type": "text", "text": content}], + "stop_reason": "stop_sequence" if stopped_seq else "end_turn", + "stop_sequence": stopped_seq, + "usage": { + "input_tokens": prompt_tokens, + "output_tokens": completion_tokens, + }, + } def run_server( diff --git a/tests/inference/test_server.py b/tests/inference/test_server.py index a65d828..63fbc67 100644 --- a/tests/inference/test_server.py +++ b/tests/inference/test_server.py @@ -1,7 +1,5 @@ """Unit tests for the inference HTTP server.""" -from unittest.mock import MagicMock - import pytest @@ -24,52 +22,6 @@ def test_health_with_model(client, loaded_model): assert data["model_loaded"] is True -def test_generate_non_stream(client, loaded_model, monkeypatch): - """POST /generate with stream=false should return JSON response.""" - response = client.post( - "/generate", - params={ - "query": "Hello", - "temperature": 0.8, - "top_p": 0.95, - "top_k": 50, - "max_len": 100, - "stream": False, - }, - ) - assert response.status_code == 200 - data = response.json() - assert "response" in data - - -def test_generate_stream(client, loaded_model, monkeypatch): - """POST /generate with stream=true should return plain text stream.""" - - async def async_gen(): - yield "chunk1" - yield "chunk2" - - mock_engine = loaded_model - mock_engine.generate_async.return_value = async_gen() - monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine) - response = client.post( - "/generate", - params={ - "query": "Hello", - "temperature": 0.8, - "top_p": 0.95, - "top_k": 50, - "max_len": 100, - "stream": True, - }, - headers={"Accept": "text/plain"}, - ) - assert response.status_code == 200 - content = response.content.decode("utf-8") - assert "chunk1" in content - assert "chunk2" in content - - def test_chat_completions_non_stream(client, loaded_model, monkeypatch): """POST /v1/chat/completions with stream=false returns OpenAI-style JSON.""" @@ -125,17 +77,87 @@ def test_chat_completions_stream(client, loaded_model, monkeypatch): assert any("[DONE]" in line for line in lines) -def test_generate_with_history(client, loaded_model, monkeypatch): - """POST /generate with history parameter.""" +def test_messages_non_stream(client, loaded_model, monkeypatch): + """POST /v1/messages with stream=false returns Anthropic-style JSON.""" + + async def async_gen(): + yield "Assistant reply" + + mock_engine = loaded_model + mock_engine.generate_async.return_value = async_gen() + monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine) response = client.post( - "/generate", - params={ - "query": "Hi", - "history": [["user1", "assistant1"], ["user2", "assistant2"]], + "/v1/messages", + json={ + "messages": [{"role": "user", "content": "Hello"}], + "temperature": 0.8, + "max_tokens": 100, "stream": False, }, ) assert response.status_code == 200 + data = response.json() + assert data["type"] == "message" + assert data["role"] == "assistant" + assert len(data["content"]) == 1 + assert data["content"][0]["type"] == "text" + assert "usage" in data + assert "input_tokens" in data["usage"] + + +def test_messages_stream(client, loaded_model, monkeypatch): + """POST /v1/messages with stream=true returns Anthropic SSE stream.""" + + async def async_gen(): + yield "cumulative1" + yield "cumulative2" + + mock_engine = loaded_model + mock_engine.generate_async.return_value = async_gen() + monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine) + response = client.post( + "/v1/messages", + json={ + "messages": [{"role": "user", "content": "Hello"}], + "temperature": 0.8, + "max_tokens": 100, + "stream": True, + }, + headers={"Accept": "text/event-stream"}, + ) + assert response.status_code == 200 + content = response.content.decode("utf-8") + assert "message_start" in content + assert "content_block_start" in content + assert "content_block_delta" in content + assert "cumulative1" in content + assert "cumulative2" in content + assert "content_block_stop" in content + assert "message_delta" in content + assert "message_stop" in content + + +def test_messages_with_system(client, loaded_model, monkeypatch): + """POST /v1/messages with system prompt.""" + + async def async_gen(): + yield "Reply" + + mock_engine = loaded_model + mock_engine.generate_async.return_value = async_gen() + monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine) + response = client.post( + "/v1/messages", + json={ + "messages": [{"role": "user", "content": "Hello"}], + "system": "You are a helpful assistant.", + "max_tokens": 100, + "stream": False, + }, + ) + assert response.status_code == 200 + data = response.json() + assert data["type"] == "message" if __name__ == "__main__":