test: 增加测试单元
This commit is contained in:
parent
20e73c05e0
commit
edb09a7ac1
456
assets/API.md
456
assets/API.md
|
|
@ -18,7 +18,7 @@
|
|||
```json
|
||||
{
|
||||
"success": true,
|
||||
"message": "注册成功",
|
||||
"message": "Registration successful",
|
||||
"data": {
|
||||
"id": 1,
|
||||
"username": "string"
|
||||
|
|
@ -41,19 +41,31 @@
|
|||
```json
|
||||
{
|
||||
"success": true,
|
||||
"message": "登录成功",
|
||||
"message": "Login successful",
|
||||
"data": {
|
||||
"access_token": "eyJ...",
|
||||
"token_type": "bearer",
|
||||
"user": {
|
||||
"id": 1,
|
||||
"username": "string",
|
||||
"role": "user"
|
||||
"email": "user@example.com",
|
||||
"role": "user",
|
||||
"permission_level": 1,
|
||||
"workspace_path": null,
|
||||
"is_active": true
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**用户权限级别:**
|
||||
| 级别 | 名称 | 说明 |
|
||||
|------|------|------|
|
||||
| 1 | READ_ONLY | 只读权限 |
|
||||
| 2 | WRITE | 写入权限 |
|
||||
| 3 | EXECUTE | 执行权限 |
|
||||
| 4 | ADMIN | 管理员权限 |
|
||||
|
||||
### POST /api/auth/logout
|
||||
用户登出
|
||||
|
||||
|
|
@ -63,7 +75,7 @@
|
|||
```json
|
||||
{
|
||||
"success": true,
|
||||
"message": "登出成功"
|
||||
"message": "Logout successful"
|
||||
}
|
||||
```
|
||||
|
||||
|
|
@ -81,11 +93,39 @@
|
|||
"username": "string",
|
||||
"email": "user@example.com",
|
||||
"role": "user",
|
||||
"is_active": true
|
||||
"permission_level": 1,
|
||||
"workspace_path": null,
|
||||
"is_active": true,
|
||||
"created_at": "2024-01-01T00:00:00"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### GET /api/auth/users
|
||||
获取所有用户(管理员专用)
|
||||
|
||||
**请求头:** `Authorization: Bearer <token>`
|
||||
|
||||
**响应:**
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"users": [...]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### PUT /api/auth/users/{user_id}
|
||||
更新用户权限(管理员专用)
|
||||
|
||||
**请求体:**
|
||||
```json
|
||||
{
|
||||
"permission_level": 2
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 会话 `/api/conversations`
|
||||
|
|
@ -123,10 +163,11 @@
|
|||
{
|
||||
"project_id": "string (可选)",
|
||||
"title": "新会话",
|
||||
"model": "glm-5",
|
||||
"system_prompt": "string (可选)",
|
||||
"temperature": 1.0,
|
||||
"max_tokens": 65536,
|
||||
"model": "deepseek-chat",
|
||||
"provider_id": 1,
|
||||
"system_prompt": "You are a helpful assistant. (可选)",
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 2000,
|
||||
"thinking_enabled": false
|
||||
}
|
||||
```
|
||||
|
|
@ -139,9 +180,15 @@
|
|||
"data": {
|
||||
"id": "conv_xxx",
|
||||
"user_id": 1,
|
||||
"provider_id": 1,
|
||||
"title": "新会话",
|
||||
"model": "glm-5",
|
||||
...
|
||||
"model": "deepseek-chat",
|
||||
"system_prompt": "You are a helpful assistant.",
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 2000,
|
||||
"thinking_enabled": false,
|
||||
"created_at": "2024-01-01T00:00:00",
|
||||
"updated_at": "2024-01-01T00:00:00"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
|
@ -149,32 +196,92 @@
|
|||
### GET /api/conversations/{id}
|
||||
获取会话详情
|
||||
|
||||
**路径参数:**
|
||||
- `id`: 会话ID
|
||||
**路径参数:** `id`: 会话ID
|
||||
|
||||
**请求头:** `Authorization: Bearer <token>`
|
||||
|
||||
### PUT /api/conversations/{id}
|
||||
更新会话
|
||||
|
||||
**路径参数:** `id`: 会话ID
|
||||
|
||||
**请求头:** `Authorization: Bearer <token>`
|
||||
|
||||
**请求体:**
|
||||
```json
|
||||
{
|
||||
"title": "新标题",
|
||||
"model": "gpt-4",
|
||||
"provider_id": 1,
|
||||
"system_prompt": "You are...",
|
||||
"temperature": 0.8,
|
||||
"max_tokens": 4000,
|
||||
"thinking_enabled": true
|
||||
}
|
||||
```
|
||||
|
||||
### DELETE /api/conversations/{id}
|
||||
删除会话
|
||||
|
||||
**路径参数:** `id`: 会话ID
|
||||
|
||||
**请求头:** `Authorization: Bearer <token>`
|
||||
|
||||
---
|
||||
|
||||
## 消息 `/api/messages`
|
||||
|
||||
### GET /api/messages/{conversation_id}
|
||||
### GET /api/messages/
|
||||
获取消息列表
|
||||
|
||||
**路径参数:**
|
||||
- `conversation_id`: 会话ID
|
||||
|
||||
**查询参数:**
|
||||
- `conversation_id`: 会话ID(必需)
|
||||
- `limit` (可选): 返回数量,默认100
|
||||
|
||||
**请求头:** `Authorization: Bearer <token>`
|
||||
|
||||
**响应:**
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"messages": [
|
||||
{
|
||||
"id": "msg_xxx",
|
||||
"conversation_id": "conv_xxx",
|
||||
"role": "user",
|
||||
"content": "用户消息",
|
||||
"text": "用户消息",
|
||||
"attachments": [],
|
||||
"process_steps": [],
|
||||
"token_count": 10,
|
||||
"usage": null,
|
||||
"created_at": "2024-01-01T00:00:00"
|
||||
},
|
||||
{
|
||||
"id": "msg_yyy",
|
||||
"conversation_id": "conv_xxx",
|
||||
"role": "assistant",
|
||||
"content": "AI 回复文本内容",
|
||||
"text": "AI 回复文本内容",
|
||||
"attachments": [],
|
||||
"process_steps": [
|
||||
{"id": "step-0", "index": 0, "type": "thinking", "content": "让我思考..."},
|
||||
{"id": "step-1", "index": 1, "type": "text", "content": "根据搜索结果..."},
|
||||
{"id": "step-2", "index": 2, "type": "tool_call", "id_ref": "call_xxx", "name": "web_search", "arguments": "..."},
|
||||
{"id": "step-3", "index": 3, "type": "tool_result", "id_ref": "call_xxx", "name": "web_search", "content": "...", "success": true}
|
||||
],
|
||||
"token_count": 100,
|
||||
"usage": {"prompt_tokens": 50, "completion_tokens": 50, "total_tokens": 100},
|
||||
"created_at": "2024-01-01T00:00:01"
|
||||
}
|
||||
],
|
||||
"title": "会话标题",
|
||||
"first_message": "用户的第一条消息..."
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### POST /api/messages/
|
||||
发送消息(非流式)
|
||||
|
||||
|
|
@ -185,7 +292,7 @@
|
|||
{
|
||||
"conversation_id": "conv_xxx",
|
||||
"content": "用户消息",
|
||||
"tools_enabled": true
|
||||
"thinking_enabled": false
|
||||
}
|
||||
```
|
||||
|
||||
|
|
@ -201,20 +308,182 @@
|
|||
```
|
||||
|
||||
### POST /api/messages/stream
|
||||
发送消息(流式响应)
|
||||
发送消息(流式响应 - SSE)
|
||||
|
||||
使用 Server-Sent Events (SSE) 返回流式响应。
|
||||
|
||||
**事件类型:**
|
||||
- `text`: 文本增量
|
||||
- `tool_call`: 工具调用
|
||||
- `tool_result`: 工具结果
|
||||
- `done`: 完成
|
||||
- `error`: 错误
|
||||
**请求头:** `Authorization: Bearer <token>`
|
||||
|
||||
### DELETE /api/messages/{id}
|
||||
**请求体:**
|
||||
```json
|
||||
{
|
||||
"conversation_id": "conv_xxx",
|
||||
"content": "用户消息",
|
||||
"thinking_enabled": true,
|
||||
"enabled_tools": ["web_search", "file_read", "python_execute"]
|
||||
}
|
||||
```
|
||||
|
||||
**SSE 事件类型:**
|
||||
|
||||
#### process_step
|
||||
结构化步骤事件(渲染顺序的唯一数据源)
|
||||
|
||||
```json
|
||||
event: process_step
|
||||
data: {"step": {"id": "step-0", "index": 0, "type": "thinking", "content": "让我思考一下..."}}
|
||||
|
||||
event: process_step
|
||||
data: {"step": {"id": "step-1", "index": 1, "type": "text", "content": "以下是搜索结果:"}}
|
||||
|
||||
event: process_step
|
||||
data: {"step": {"id": "step-2", "index": 2, "type": "tool_call", "id_ref": "call_abc", "name": "web_search", "arguments": "{\"query\": \"...\"}"}}
|
||||
|
||||
event: process_step
|
||||
data: {"step": {"id": "step-3", "index": 3, "type": "tool_result", "id_ref": "call_abc", "name": "web_search", "content": "{...}", "success": true}}
|
||||
```
|
||||
|
||||
**步骤类型说明:**
|
||||
|
||||
| type | 说明 | 额外字段 |
|
||||
|------|------|---------|
|
||||
| `thinking` | 模型思考过程 | `content` |
|
||||
| `text` | 文本回复 | `content` |
|
||||
| `tool_call` | 工具调用 | `id_ref`, `name`, `arguments` |
|
||||
| `tool_result` | 工具执行结果 | `id_ref`, `name`, `content`, `success` |
|
||||
|
||||
#### done
|
||||
响应完成
|
||||
|
||||
```json
|
||||
event: done
|
||||
data: {"message_id": "msg_xxx", "token_count": 150, "usage": {"prompt_tokens": 100, "completion_tokens": 50, "total_tokens": 150}}
|
||||
```
|
||||
|
||||
#### error
|
||||
错误信息
|
||||
|
||||
```json
|
||||
event: error
|
||||
data: {"content": "错误信息描述"}
|
||||
```
|
||||
|
||||
### DELETE /api/messages/{message_id}
|
||||
删除消息
|
||||
|
||||
**路径参数:** `message_id`: 消息ID
|
||||
|
||||
**请求头:** `Authorization: Bearer <token>`
|
||||
|
||||
---
|
||||
|
||||
## LLM 提供商 `/api/providers`
|
||||
|
||||
### GET /api/providers/
|
||||
获取用户的 LLM 提供商列表
|
||||
|
||||
**请求头:** `Authorization: Bearer <token>`
|
||||
|
||||
**响应:**
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"providers": [
|
||||
{
|
||||
"id": 1,
|
||||
"user_id": 1,
|
||||
"name": "DeepSeek",
|
||||
"provider_type": "openai",
|
||||
"base_url": "https://api.deepseek.com/v1",
|
||||
"default_model": "deepseek-chat",
|
||||
"max_tokens": 8192,
|
||||
"is_default": true,
|
||||
"enabled": true,
|
||||
"created_at": "2024-01-01T00:00:00",
|
||||
"updated_at": "2024-01-01T00:00:00"
|
||||
}
|
||||
],
|
||||
"total": 1
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### POST /api/providers/
|
||||
创建 LLM 提供商
|
||||
|
||||
**请求头:** `Authorization: Bearer <token>`
|
||||
|
||||
**请求体:**
|
||||
```json
|
||||
{
|
||||
"name": "DeepSeek",
|
||||
"provider_type": "openai",
|
||||
"base_url": "https://api.deepseek.com/v1",
|
||||
"api_key": "sk-xxxx",
|
||||
"default_model": "deepseek-chat",
|
||||
"is_default": true
|
||||
}
|
||||
```
|
||||
|
||||
**provider_type 可选值:**
|
||||
- `openai` - OpenAI/DeepSeek/GLM 兼容 API
|
||||
- `anthropic` - Anthropic Claude API
|
||||
|
||||
### GET /api/providers/{provider_id}
|
||||
获取提供商详情
|
||||
|
||||
**路径参数:** `provider_id`: 提供商ID
|
||||
|
||||
**请求头:** `Authorization: Bearer <token>`
|
||||
|
||||
### PUT /api/providers/{provider_id}
|
||||
更新提供商
|
||||
|
||||
**路径参数:** `provider_id`: 提供商ID
|
||||
|
||||
**请求头:** `Authorization: Bearer <token>`
|
||||
|
||||
**请求体:**
|
||||
```json
|
||||
{
|
||||
"name": "新名称",
|
||||
"base_url": "https://api.example.com/v1",
|
||||
"api_key": "sk-yyyy",
|
||||
"default_model": "gpt-4",
|
||||
"max_tokens": 16384,
|
||||
"is_default": false,
|
||||
"enabled": true
|
||||
}
|
||||
```
|
||||
|
||||
### DELETE /api/providers/{provider_id}
|
||||
删除提供商
|
||||
|
||||
**路径参数:** `provider_id`: 提供商ID
|
||||
|
||||
**请求头:** `Authorization: Bearer <token>`
|
||||
|
||||
### POST /api/providers/{provider_id}/test
|
||||
测试提供商连接
|
||||
|
||||
**路径参数:** `provider_id`: 提供商ID
|
||||
|
||||
**请求头:** `Authorization: Bearer <token>`
|
||||
|
||||
**响应:**
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"message": "HTTP 200: ...",
|
||||
"data": {
|
||||
"status_code": 200,
|
||||
"success": true,
|
||||
"response_body": "..."
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 工具 `/api/tools`
|
||||
|
|
@ -223,7 +492,7 @@
|
|||
获取可用工具列表
|
||||
|
||||
**查询参数:**
|
||||
- `category` (可选): 工具分类
|
||||
- `category` (可选): 工具分类(code/file/shell/crawler/data)
|
||||
|
||||
**请求头:** `Authorization: Bearer <token>`
|
||||
|
||||
|
|
@ -232,12 +501,21 @@
|
|||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"tools": [...],
|
||||
"tools": [
|
||||
{
|
||||
"name": "python_execute",
|
||||
"description": "Execute Python code",
|
||||
"category": "code",
|
||||
"parameters": {...}
|
||||
},
|
||||
...
|
||||
],
|
||||
"categorized": {
|
||||
"crawler": [...],
|
||||
"code": [...],
|
||||
"data": [...],
|
||||
"weather": [...]
|
||||
"file": [...],
|
||||
"shell": [...],
|
||||
"crawler": [...],
|
||||
"data": [...]
|
||||
},
|
||||
"total": 11
|
||||
}
|
||||
|
|
@ -247,9 +525,39 @@
|
|||
### GET /api/tools/{name}
|
||||
获取工具详情
|
||||
|
||||
**路径参数:** `name`: 工具名称
|
||||
|
||||
**请求头:** `Authorization: Bearer <token>`
|
||||
|
||||
**响应:**
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"name": "web_search",
|
||||
"description": "Search the web using DuckDuckGo",
|
||||
"category": "crawler",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### POST /api/tools/{name}/execute
|
||||
手动执行工具
|
||||
|
||||
**路径参数:** `name`: 工具名称
|
||||
|
||||
**请求头:** `Authorization: Bearer <token>`
|
||||
|
||||
**请求体:**
|
||||
```json
|
||||
{
|
||||
|
|
@ -257,3 +565,89 @@
|
|||
"arg2": "value2"
|
||||
}
|
||||
```
|
||||
|
||||
**响应:**
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"result": "..."
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 公共端点
|
||||
|
||||
### GET /api/health
|
||||
健康检查
|
||||
|
||||
**响应:**
|
||||
```json
|
||||
{
|
||||
"status": "ok",
|
||||
"message": "Luxx API is running"
|
||||
}
|
||||
```
|
||||
|
||||
### GET /api/
|
||||
服务信息
|
||||
|
||||
**响应:**
|
||||
```json
|
||||
{
|
||||
"name": "Luxx",
|
||||
"version": "1.0.0",
|
||||
"description": "AI Chat API"
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 工具说明
|
||||
|
||||
### 内置工具
|
||||
|
||||
#### 代码执行 (code)
|
||||
| 工具 | 功能 | 权限 |
|
||||
|------|------|------|
|
||||
| `python_execute` | 执行 Python 代码 | EXECUTE |
|
||||
| `python_eval` | 计算表达式 | EXECUTE |
|
||||
|
||||
#### 文件操作 (file)
|
||||
| 工具 | 功能 | 权限 |
|
||||
|------|------|------|
|
||||
| `file_read` | 读取文件内容 | READ_ONLY |
|
||||
| `file_write` | 写入文件内容 | WRITE |
|
||||
| `file_list` | 列出目录内容 | READ_ONLY |
|
||||
| `file_exists` | 检查文件是否存在 | READ_ONLY |
|
||||
| `file_grep` | 正则搜索文件 | READ_ONLY |
|
||||
|
||||
#### Shell 命令 (shell)
|
||||
| 工具 | 功能 | 权限 |
|
||||
|------|------|------|
|
||||
| `shell_execute` | 执行 Shell 命令 | EXECUTE |
|
||||
|
||||
#### 网页爬虫 (crawler)
|
||||
| 工具 | 功能 | 权限 |
|
||||
|------|------|------|
|
||||
| `web_search` | DuckDuckGo 搜索 | READ_ONLY |
|
||||
| `web_fetch` | 网页抓取 | READ_ONLY |
|
||||
| `batch_fetch` | 批量并发抓取 | READ_ONLY |
|
||||
|
||||
#### 数据处理 (data)
|
||||
| 工具 | 功能 | 权限 |
|
||||
|------|------|------|
|
||||
| `process_data` | JSON 转换、格式化 | READ_ONLY |
|
||||
|
||||
### 权限检查
|
||||
|
||||
工具执行时自动检查用户权限:
|
||||
|
||||
```
|
||||
工具要求的权限 <= 用户拥有的权限 → 允许执行
|
||||
工具要求的权限 > 用户拥有的权限 → 返回错误
|
||||
```
|
||||
|
||||
用户通过 `/api/auth/users/{user_id}` 接口设置权限级别。
|
||||
|
|
|
|||
|
|
@ -29,14 +29,22 @@ luxx/
|
|||
│ ├── providers.py # LLM 提供商管理
|
||||
│ └── tools.py # 工具管理
|
||||
├── services/ # 服务层
|
||||
│ ├── __init__.py # 服务导出
|
||||
│ ├── chat.py # 聊天服务门面
|
||||
│ ├── agentic_loop.py # Agentic Loop 执行器
|
||||
│ ├── stream_context.py # 流式状态管理
|
||||
│ ├── llm_response.py # LLM 响应解析器
|
||||
│ ├── process_result.py# 处理结果
|
||||
│ └── llm_client.py # LLM 客户端
|
||||
│ ├── llm_response.py # LLM 响应数据类
|
||||
│ ├── process_result.py # [已移除]
|
||||
│ ├── task.py # 任务系统 (Task/TaskGraph/TaskService)
|
||||
│ ├── llm_client.py # LLM 客户端
|
||||
│ └── llm_adapters/ # LLM API 适配器
|
||||
│ ├── __init__.py # 适配器导出
|
||||
│ ├── base.py # ProviderAdapter 基类
|
||||
│ ├── openai_adapter.py # OpenAI/DeepSeek/GLM 适配器
|
||||
│ └── anthropic_adapter.py # Anthropic Claude 适配器
|
||||
├── tools/ # 工具系统
|
||||
│ ├── core.py # 核心类 (ToolRegistry, ToolDefinition, ToolResult)
|
||||
│ ├── __init__.py # 工具注册入口
|
||||
│ ├── core.py # 核心类 (ToolRegistry, ToolDefinition, ToolResult, ToolContext)
|
||||
│ ├── factory.py # @tool 装饰器
|
||||
│ ├── executor.py # 工具执行器 (缓存/并行)
|
||||
│ ├── services.py # 工具服务层
|
||||
|
|
@ -44,8 +52,11 @@ luxx/
|
|||
│ ├── __init__.py # 工具注册入口
|
||||
│ ├── code.py # 代码执行 (python_execute, python_eval)
|
||||
│ ├── crawler.py # 网页爬虫 (web_search, web_fetch, batch_fetch)
|
||||
│ └── data.py # 数据处理 (process_data)
|
||||
│ ├── data.py # 数据处理 (process_data)
|
||||
│ ├── file.py # 文件操作 (file_read, file_write, file_list, file_exists, file_grep)
|
||||
│ └── shell.py # Shell 命令 (shell_execute)
|
||||
└── utils/ # 工具函数
|
||||
├── __init__.py
|
||||
└── helpers.py # 密码哈希、ID生成、响应封装
|
||||
|
||||
run.py # 应用入口文件
|
||||
|
|
@ -77,15 +88,36 @@ password: admin123
|
|||
app:
|
||||
secret_key: ${APP_SECRET_KEY}
|
||||
debug: true
|
||||
host: 0.0.0.0
|
||||
port: 8000
|
||||
|
||||
database:
|
||||
type: sqlite
|
||||
url: sqlite:///./chat.db
|
||||
|
||||
workspace:
|
||||
root: ./workspaces # 用户工作空间根目录
|
||||
auto_create: true # 自动创建用户目录
|
||||
|
||||
llm:
|
||||
provider: deepseek
|
||||
api_key: ${DEEPSEEK_API_KEY}
|
||||
api_url: https://api.deepseek.com/v1
|
||||
|
||||
tools:
|
||||
enable_cache: true
|
||||
cache_ttl: 300
|
||||
max_workers: 4
|
||||
max_iterations: 10
|
||||
|
||||
logging:
|
||||
level: INFO
|
||||
```
|
||||
|
||||
**工作空间隔离机制:**
|
||||
- 每个用户的工作空间路径基于 `user_id` 的 SHA256 哈希值
|
||||
- 格式:`{workspace_root}/{hash_of_user_id}`
|
||||
- 所有文件操作必须在用户工作空间内,防止路径穿越攻击
|
||||
```
|
||||
|
||||
### 3. 数据库 (`database.py`)
|
||||
|
|
@ -103,6 +135,8 @@ erDiagram
|
|||
string email UK
|
||||
string password_hash
|
||||
string role
|
||||
int permission_level "1=READ_ONLY, 2=WRITE, 3=EXECUTE, 4=ADMIN"
|
||||
string workspace_path "用户工作空间路径"
|
||||
boolean is_active
|
||||
datetime created_at
|
||||
}
|
||||
|
|
@ -164,6 +198,14 @@ erDiagram
|
|||
CONVERSATION ||--o{ MESSAGE : "has"
|
||||
```
|
||||
|
||||
**用户权限级别 (permission_level):**
|
||||
| 级别 | 名称 | 说明 |
|
||||
|------|------|------|
|
||||
| 1 | READ_ONLY | 只读权限 |
|
||||
| 2 | WRITE | 写入权限(文件写入) |
|
||||
| 3 | EXECUTE | 执行权限(代码执行、Shell命令) |
|
||||
| 4 | ADMIN | 管理员权限 |
|
||||
|
||||
### Message Content JSON 结构
|
||||
|
||||
`content` 字段统一使用 JSON 格式存储:
|
||||
|
|
@ -183,8 +225,6 @@ erDiagram
|
|||
|
||||
```json
|
||||
{
|
||||
"text": "AI 回复的文本内容",
|
||||
"tool_calls": [...],
|
||||
"steps": [
|
||||
{"id": "step-0", "index": 0, "type": "thinking", "content": "..."},
|
||||
{"id": "step-1", "index": 1, "type": "text", "content": "..."},
|
||||
|
|
@ -194,7 +234,9 @@ erDiagram
|
|||
}
|
||||
```
|
||||
|
||||
`steps` 字段是**渲染顺序的唯一数据源**,按 `index` 顺序排列。thinking、text、tool_call、tool_result 可以在多轮迭代中穿插出现。
|
||||
`steps` 字段是**唯一数据源**,按 `index` 顺序排列。thinking、text、tool_call、tool_result 可以在多轮迭代中穿插出现。
|
||||
|
||||
**注意**:`text` 和 `content` 字段通过解析 `steps` 中所有 `type: "text"` 的内容动态计算得出。
|
||||
|
||||
### 5. 工具系统
|
||||
|
||||
|
|
@ -206,9 +248,25 @@ classDiagram
|
|||
+dict parameters
|
||||
+Callable handler
|
||||
+str category
|
||||
+CommandPermission required_permission
|
||||
+to_openai_format() dict
|
||||
}
|
||||
|
||||
class ToolContext {
|
||||
+int user_id
|
||||
+str username
|
||||
+str workspace
|
||||
+int user_permission_level
|
||||
}
|
||||
|
||||
class CommandPermission {
|
||||
<<enumeration>>
|
||||
READ_ONLY = 1
|
||||
WRITE = 2
|
||||
EXECUTE = 3
|
||||
ADMIN = 4
|
||||
}
|
||||
|
||||
class ToolResult {
|
||||
+bool success
|
||||
+Any data
|
||||
|
|
@ -224,7 +282,7 @@ classDiagram
|
|||
+get(name) ToolDefinition?
|
||||
+list_all() List~dict~
|
||||
+list_by_category(category) List~dict~
|
||||
+execute(name, arguments) dict
|
||||
+execute(name, arguments, context) dict
|
||||
+remove(name) bool
|
||||
}
|
||||
|
||||
|
|
@ -243,14 +301,51 @@ classDiagram
|
|||
|
||||
#### 内置工具
|
||||
|
||||
| 工具 | 功能 | 说明 |
|
||||
**代码执行 (code.py)**
|
||||
|
||||
| 工具 | 功能 | 权限 |
|
||||
|------|------|------|
|
||||
| `python_execute` | 执行 Python 代码 | 支持 print 输出、变量访问 |
|
||||
| `python_eval` | 计算表达式 | 快速求值 |
|
||||
| `web_search` | DuckDuckGo HTML | DuckDuckGo HTML 搜索 |
|
||||
| `web_fetch` | 网页抓取 | httpx + BeautifulSoup,支持 text/links/structured |
|
||||
| `batch_fetch` | 批量抓取 | 并发获取多个页面 |
|
||||
| `process_data` | 数据处理 | JSON 转换、格式化等 |
|
||||
| `python_execute` | 执行 Python 代码 | EXECUTE |
|
||||
| `python_eval` | 计算表达式 | EXECUTE |
|
||||
|
||||
**文件操作 (file.py)**
|
||||
|
||||
| 工具 | 功能 | 权限 |
|
||||
|------|------|------|
|
||||
| `file_read` | 读取文件内容 | READ_ONLY |
|
||||
| `file_write` | 写入文件内容 | WRITE |
|
||||
| `file_list` | 列出目录内容 | READ_ONLY |
|
||||
| `file_exists` | 检查文件是否存在 | READ_ONLY |
|
||||
| `file_grep` | 正则搜索文件内容 | READ_ONLY |
|
||||
|
||||
**Shell 命令 (shell.py)**
|
||||
|
||||
| 工具 | 功能 | 权限 |
|
||||
|------|------|------|
|
||||
| `shell_execute` | 执行 Shell 命令 | EXECUTE |
|
||||
|
||||
**网页爬虫 (crawler.py)**
|
||||
|
||||
| 工具 | 功能 | 权限 |
|
||||
|------|------|------|
|
||||
| `web_search` | DuckDuckGo HTML 搜索 | READ_ONLY |
|
||||
| `web_fetch` | 网页抓取 | READ_ONLY |
|
||||
| `batch_fetch` | 批量并发抓取 | READ_ONLY |
|
||||
|
||||
**数据处理 (data.py)**
|
||||
|
||||
| 工具 | 功能 | 权限 |
|
||||
|------|------|------|
|
||||
| `process_data` | JSON 转换、格式化 | READ_ONLY |
|
||||
|
||||
#### 权限检查机制
|
||||
|
||||
工具执行时自动检查用户权限:
|
||||
|
||||
```
|
||||
工具要求的权限 <= 用户拥有的权限 → 允许执行
|
||||
工具要求的权限 > 用户拥有的权限 → 拒绝执行
|
||||
```
|
||||
|
||||
#### 工具开发规范
|
||||
|
||||
|
|
@ -312,26 +407,82 @@ ToolExecutor 返回结果
|
|||
|
||||
### 6. 服务层
|
||||
|
||||
#### LLMResponseParser (`services/llm_response.py`)
|
||||
统一解析器,兼容多种 LLM API 格式:
|
||||
- **OpenAI**: `delta.content`, `delta.tool_calls`
|
||||
- **DeepSeek**: `delta.content`, `delta.reasoning_content`
|
||||
- **Anthropic**: `content_block` 类型事件
|
||||
- **MiniMax**: `<|im_start|>thinking...<|im_end|>` 标签
|
||||
#### LLM 适配器 (`services/llm_adapters/`)
|
||||
|
||||
适配器模式统一处理不同 LLM API 格式:
|
||||
|
||||
```mermaid
|
||||
classDiagram
|
||||
class ProviderAdapter {
|
||||
<<abstract>>
|
||||
+str provider_type
|
||||
+build_request() tuple
|
||||
+parse_stream_chunk() AsyncGenerator
|
||||
+parse_response() Dict
|
||||
+supports_thinking() bool
|
||||
+supports_tools() bool
|
||||
}
|
||||
|
||||
class OpenAIAdapter {
|
||||
+str provider_type = "openai"
|
||||
+build_request() tuple
|
||||
+parse_stream_chunk() AsyncGenerator
|
||||
+parse_response() Dict
|
||||
+supports_tools() bool
|
||||
}
|
||||
|
||||
class AnthropicAdapter {
|
||||
+str provider_type = "anthropic"
|
||||
+build_request() tuple
|
||||
+parse_stream_chunk() AsyncGenerator
|
||||
+parse_response() Dict
|
||||
+supports_thinking() bool
|
||||
+supports_tools() bool
|
||||
}
|
||||
|
||||
ProviderAdapter <|-- OpenAIAdapter
|
||||
ProviderAdapter <|-- AnthropicAdapter
|
||||
```
|
||||
|
||||
**支持的功能对比:**
|
||||
|
||||
| 适配器 | 工具调用 | Thinking/Reasoning | 流式响应 |
|
||||
|--------|----------|-------------------|----------|
|
||||
| OpenAI | ✅ | ✅ (DeepSeek) | ✅ |
|
||||
| Anthropic | ✅ | ✅ | ✅ |
|
||||
|
||||
#### LLM 响应数据类 (`services/llm_response.py`)
|
||||
|
||||
```python
|
||||
from luxx.services.llm_response import llm_parser
|
||||
class StepType:
|
||||
"""步骤类型常量"""
|
||||
THINKING = "thinking"
|
||||
TEXT = "text"
|
||||
TOOL_CALL = "tool_call"
|
||||
TOOL_RESULT = "tool_result"
|
||||
|
||||
# 解析 OpenAI 格式
|
||||
parsed = llm_parser.parse_openai(delta)
|
||||
|
||||
# 解析 Anthropic 格式
|
||||
parsed = llm_parser.parse_anthropic(chunk)
|
||||
@dataclass
|
||||
class Step:
|
||||
"""单个步骤 - 用于存储和传输"""
|
||||
id: str
|
||||
index: int
|
||||
type: str # thinking, text, tool_call, tool_result
|
||||
content: str = ""
|
||||
name: str = "" # tool_call/tool_result
|
||||
arguments: str = "" # tool_call
|
||||
id_ref: str = "" # tool_result
|
||||
success: bool = True
|
||||
|
||||
# 返回 ParsedDelta
|
||||
parsed.thinking # 思考内容
|
||||
parsed.text # 文本内容
|
||||
parsed.tool_calls # 工具调用
|
||||
|
||||
@dataclass
|
||||
class ParsedDelta:
|
||||
"""LLM 流式响应增量"""
|
||||
thinking: str = "" # 思考内容(增量)
|
||||
text: str = "" # 文本内容(增量)
|
||||
tool_call: Optional[Dict] = None # 单个工具调用
|
||||
usage: Dict[str, int] = {} # Token 用量
|
||||
is_complete: bool = False
|
||||
```
|
||||
|
||||
#### ChatService (`services/chat.py`)
|
||||
|
|
@ -340,30 +491,101 @@ parsed.tool_calls # 工具调用
|
|||
- 流式 SSE 响应
|
||||
- 工具调用编排(并行执行)
|
||||
- 消息历史管理
|
||||
- 自动重试机制
|
||||
- Token 用量追踪
|
||||
- 工作空间上下文传递
|
||||
|
||||
#### AgenticLoop (`services/agentic_loop.py`)
|
||||
执行 Agentic Loop 的核心循环:
|
||||
- 调用 LLM 获取响应
|
||||
- 使用 LLMResponseParser 解析响应
|
||||
- 调用 LLM 获取响应(流式)
|
||||
- 解析 ParsedDelta,更新步骤状态
|
||||
- 管理 thinking/text/tool_call/tool_result 步骤
|
||||
- 工具并行执行
|
||||
- 最大迭代次数:10
|
||||
|
||||
```python
|
||||
# 执行流程
|
||||
async for delta in llm.stream_call(...):
|
||||
events = self._process_delta(delta, context, total_usage)
|
||||
yield from events
|
||||
|
||||
# 工具调用时
|
||||
tool_results = self.tool_executor.process_tool_calls_parallel(...)
|
||||
messages.append({"role": "assistant", ...})
|
||||
messages.extend(tool_results)
|
||||
```
|
||||
|
||||
#### StreamContext (`services/stream_context.py`)
|
||||
流式状态管理:
|
||||
- 追踪当前步骤类型和索引
|
||||
- 累积 thinking 和 text 内容
|
||||
- 管理 tool_calls 列表
|
||||
- 管理 tool_calls 列表和 tool_results
|
||||
- 生成 SSE 事件
|
||||
- 构建完整消息内容
|
||||
|
||||
#### LLMClient (`services/llm_client.py`)
|
||||
LLM API 客户端:
|
||||
- 多提供商:DeepSeek、GLM、OpenAI
|
||||
- 多提供商:OpenAI、DeepSeek、Anthropic
|
||||
- 自动适配器选择
|
||||
- 流式/同步调用
|
||||
- 错误处理和重试
|
||||
- Token 计数
|
||||
|
||||
### 7. 任务系统 (`services/task.py`)
|
||||
|
||||
用于自主任务执行和依赖管理:
|
||||
|
||||
```mermaid
|
||||
classDiagram
|
||||
class Task {
|
||||
+str id
|
||||
+str name
|
||||
+str goal
|
||||
+TaskStatus status
|
||||
+List~Step~ steps
|
||||
+List~Task~ subtasks
|
||||
}
|
||||
|
||||
class Step {
|
||||
+str id
|
||||
+str name
|
||||
+List~str~ depends_on
|
||||
+StepStatus status
|
||||
}
|
||||
|
||||
class TaskGraph {
|
||||
+topological_sort() List~Step~
|
||||
+get_ready_steps() List~Step~
|
||||
+detect_cycles() List~List~str~~
|
||||
+validate() tuple
|
||||
}
|
||||
|
||||
class TaskService {
|
||||
+create_task() Task
|
||||
+get_task() Task
|
||||
+update_task_status() Task
|
||||
+add_steps() List~Step~
|
||||
+build_graph() TaskGraph
|
||||
}
|
||||
|
||||
Task "1" o-- "*" Step
|
||||
Task "1" o-- "*" Task
|
||||
TaskService ..> TaskGraph
|
||||
```
|
||||
|
||||
**任务状态 (TaskStatus):**
|
||||
- `PENDING` - 待处理
|
||||
- `READY` - 就绪
|
||||
- `RUNNING` - 运行中
|
||||
- `BLOCK` - 阻塞
|
||||
- `TERMINATED` - 已终止
|
||||
|
||||
**步骤状态 (StepStatus):**
|
||||
- `PENDING` - 待执行
|
||||
- `RUNNING` - 执行中
|
||||
- `COMPLETED` - 已完成
|
||||
- `FAILED` - 失败
|
||||
- `SKIPPED` - 跳过
|
||||
|
||||
### 7. 认证系统 (`routes/auth.py`)
|
||||
- JWT Bearer Token
|
||||
- Bcrypt 密码哈希
|
||||
|
|
@ -481,6 +703,10 @@ database:
|
|||
type: sqlite
|
||||
url: sqlite:///./chat.db
|
||||
|
||||
workspace:
|
||||
root: ./workspaces # 用户工作空间根目录
|
||||
auto_create: true # 自动创建用户工作空间
|
||||
|
||||
llm:
|
||||
provider: deepseek
|
||||
api_key: ${DEEPSEEK_API_KEY}
|
||||
|
|
@ -491,6 +717,9 @@ tools:
|
|||
cache_ttl: 300
|
||||
max_workers: 4
|
||||
max_iterations: 10
|
||||
|
||||
logging:
|
||||
level: INFO
|
||||
```
|
||||
|
||||
## 环境变量
|
||||
|
|
@ -501,6 +730,26 @@ tools:
|
|||
| `DEEPSEEK_API_KEY` | DeepSeek API | `sk-xxxx` |
|
||||
| `DATABASE_URL` | 数据库连接 | `sqlite:///./chat.db` |
|
||||
|
||||
## LLM 适配器配置
|
||||
|
||||
### OpenAI 兼容 (DeepSeek/GLM 等)
|
||||
|
||||
```yaml
|
||||
llm:
|
||||
provider: openai
|
||||
api_key: ${API_KEY}
|
||||
api_url: https://api.deepseek.com/v1 # 或其他兼容端点
|
||||
```
|
||||
|
||||
### Anthropic Claude
|
||||
|
||||
```yaml
|
||||
llm:
|
||||
provider: anthropic
|
||||
api_key: ${ANTHROPIC_API_KEY}
|
||||
api_url: https://api.anthropic.com/v1
|
||||
```
|
||||
|
||||
## 项目结构说明
|
||||
|
||||
### 入口文件
|
||||
|
|
@ -530,5 +779,21 @@ ToolExecutor 支持结果缓存:
|
|||
|
||||
1. 实时返回 thinking_content(模型思考过程)
|
||||
2. 实时返回 text 增量更新
|
||||
3. 工具调用串行执行,结果批量返回
|
||||
3. 工具调用并行执行,结果批量返回
|
||||
4. 最终 `done` 事件包含完整 message_id 和 token 用量
|
||||
|
||||
### 工作空间隔离
|
||||
|
||||
每个用户的工作空间完全隔离:
|
||||
- 用户目录基于 user_id 的 SHA256 哈希生成
|
||||
- 所有文件操作强制在用户工作空间内
|
||||
- 支持权限级别控制文件操作能力
|
||||
|
||||
### MessageBuilder
|
||||
|
||||
用于构建发送给 LLM 的消息列表:
|
||||
- `add_system()` - 添加系统消息
|
||||
- `add_user()` - 添加用户消息(JSON 格式)
|
||||
- `add_assistant()` - 添加助手消息
|
||||
- `add_tool_result()` - 添加工具结果消息
|
||||
- `extract_text()` - 从 JSON 内容中提取文本
|
||||
|
|
|
|||
|
|
@ -154,8 +154,6 @@ class Message(Base):
|
|||
|
||||
**Assistant 消息:**
|
||||
{
|
||||
"text": "AI 回复的文本内容",
|
||||
"tool_calls": [...], // 遗留的扁平结构
|
||||
"steps": [ // 有序步骤,用于渲染(主要数据源)
|
||||
{"id": "step-0", "index": 0, "type": "thinking", "content": "..."},
|
||||
{"id": "step-1", "index": 1, "type": "text", "content": "..."},
|
||||
|
|
@ -163,6 +161,8 @@ class Message(Base):
|
|||
{"id": "step-3", "index": 3, "type": "tool_result", "id_ref": "call_xxx", "name": "...", "content": "..."}
|
||||
]
|
||||
}
|
||||
|
||||
注意:to_dict() 返回时会从 steps 动态计算 text 和 content 字段。
|
||||
"""
|
||||
__tablename__ = "messages"
|
||||
|
||||
|
|
@ -204,20 +204,22 @@ class Message(Base):
|
|||
result["content"] = self.content
|
||||
result["text"] = self.content
|
||||
result["attachments"] = []
|
||||
result["tool_calls"] = []
|
||||
result["process_steps"] = []
|
||||
return result
|
||||
|
||||
# Extract common fields
|
||||
result["text"] = content_obj.get("text", "")
|
||||
result["attachments"] = content_obj.get("attachments", [])
|
||||
result["tool_calls"] = content_obj.get("tool_calls", [])
|
||||
|
||||
# Extract steps as process_steps for frontend rendering
|
||||
result["process_steps"] = content_obj.get("steps", [])
|
||||
steps = content_obj.get("steps", [])
|
||||
result["process_steps"] = steps
|
||||
|
||||
# For backward compatibility
|
||||
if "content" not in result:
|
||||
result["content"] = result["text"]
|
||||
# Extract text from steps (concatenate all text type steps)
|
||||
text_content = "".join(
|
||||
s.get("content", "") for s in steps
|
||||
if s.get("type") == "text"
|
||||
)
|
||||
result["text"] = text_content
|
||||
result["content"] = text_content # Alias for convenience
|
||||
|
||||
# Extract attachments
|
||||
result["attachments"] = content_obj.get("attachments", [])
|
||||
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -0,0 +1,214 @@
|
|||
"""Repository layer for data access - follows Repository Pattern
|
||||
|
||||
This module separates data access logic from business logic, following
|
||||
the Dependency Inversion Principle (DIP) from SOLID principles.
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Optional, Dict, Any
|
||||
from contextlib import contextmanager
|
||||
|
||||
from luxx.database import SessionLocal
|
||||
from luxx.models import Message, LLMProvider, Conversation, User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RepositoryError(Exception):
|
||||
"""Base exception for repository errors"""
|
||||
pass
|
||||
|
||||
|
||||
class UnitOfWork:
|
||||
"""Unit of Work pattern for managing database sessions
|
||||
|
||||
Usage:
|
||||
with UnitOfWork() as uow:
|
||||
messages = uow.messages.get_by_conversation(conv_id)
|
||||
uow.commit()
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._session = None
|
||||
|
||||
def __enter__(self):
|
||||
self._session = SessionLocal()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if exc_type is not None:
|
||||
self._session.rollback()
|
||||
return False
|
||||
return True
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
if self._session is None:
|
||||
raise RepositoryError("UnitOfWork not started. Use 'with UnitOfWork()'")
|
||||
return self._session
|
||||
|
||||
def commit(self):
|
||||
"""Commit the current transaction"""
|
||||
try:
|
||||
self._session.commit()
|
||||
except Exception as e:
|
||||
self._session.rollback()
|
||||
raise RepositoryError(f"Commit failed: {e}") from e
|
||||
|
||||
def rollback(self):
|
||||
"""Rollback the current transaction"""
|
||||
self._session.rollback()
|
||||
|
||||
@property
|
||||
def messages(self) -> "MessageRepository":
|
||||
return MessageRepository(self._session)
|
||||
|
||||
@property
|
||||
def providers(self) -> "ProviderRepository":
|
||||
return ProviderRepository(self._session)
|
||||
|
||||
@property
|
||||
def conversations(self) -> "ConversationRepository":
|
||||
return ConversationRepository(self._session)
|
||||
|
||||
|
||||
class BaseRepository:
|
||||
"""Base repository with common operations"""
|
||||
|
||||
def __init__(self, session):
|
||||
self._session = session
|
||||
|
||||
|
||||
class MessageRepository(BaseRepository):
|
||||
"""Repository for Message data access"""
|
||||
|
||||
def get_by_id(self, msg_id: str) -> Optional[Message]:
|
||||
"""Get message by ID"""
|
||||
return self._session.query(Message).filter(Message.id == msg_id).first()
|
||||
|
||||
def get_by_conversation(self, conversation_id: str) -> List[Message]:
|
||||
"""Get all messages for a conversation, ordered by creation time"""
|
||||
return self._session.query(Message).filter(
|
||||
Message.conversation_id == conversation_id
|
||||
).order_by(Message.created_at).all()
|
||||
|
||||
def create(
|
||||
self,
|
||||
msg_id: str,
|
||||
conversation_id: str,
|
||||
role: str,
|
||||
content: Dict[str, Any],
|
||||
token_count: int = 0,
|
||||
usage: Dict[str, Any] = None
|
||||
) -> Message:
|
||||
"""Create a new message"""
|
||||
msg = Message(
|
||||
id=msg_id,
|
||||
conversation_id=conversation_id,
|
||||
role=role,
|
||||
content=json.dumps(content, ensure_ascii=False),
|
||||
token_count=token_count,
|
||||
usage=json.dumps(usage) if usage else None
|
||||
)
|
||||
self._session.add(msg)
|
||||
return msg
|
||||
|
||||
def delete(self, msg_id: str) -> bool:
|
||||
"""Delete a message by ID"""
|
||||
msg = self.get_by_id(msg_id)
|
||||
if msg:
|
||||
self._session.delete(msg)
|
||||
return True
|
||||
return False
|
||||
|
||||
def delete_by_conversation(self, conversation_id: str) -> int:
|
||||
"""Delete all messages for a conversation"""
|
||||
count = self._session.query(Message).filter(
|
||||
Message.conversation_id == conversation_id
|
||||
).delete()
|
||||
return count
|
||||
|
||||
|
||||
class ProviderRepository(BaseRepository):
|
||||
"""Repository for LLM Provider data access"""
|
||||
|
||||
def get_by_id(self, provider_id: int) -> Optional[LLMProvider]:
|
||||
"""Get provider by ID"""
|
||||
return self._session.query(LLMProvider).filter(LLMProvider.id == provider_id).first()
|
||||
|
||||
def get_by_user(self, user_id: int) -> List[LLMProvider]:
|
||||
"""Get all providers for a user"""
|
||||
return self._session.query(LLMProvider).filter(
|
||||
LLMProvider.user_id == user_id
|
||||
).all()
|
||||
|
||||
def get_default(self, user_id: int) -> Optional[LLMProvider]:
|
||||
"""Get the default provider for a user"""
|
||||
return self._session.query(LLMProvider).filter(
|
||||
LLMProvider.user_id == user_id,
|
||||
LLMProvider.is_default == True
|
||||
).first()
|
||||
|
||||
def create(self, **kwargs) -> LLMProvider:
|
||||
"""Create a new provider"""
|
||||
provider = LLMProvider(**kwargs)
|
||||
self._session.add(provider)
|
||||
return provider
|
||||
|
||||
def update(self, provider: LLMProvider) -> LLMProvider:
|
||||
"""Update an existing provider"""
|
||||
self._session.add(provider)
|
||||
return provider
|
||||
|
||||
def delete(self, provider_id: int) -> bool:
|
||||
"""Delete a provider by ID"""
|
||||
provider = self.get_by_id(provider_id)
|
||||
if provider:
|
||||
self._session.delete(provider)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class ConversationRepository(BaseRepository):
|
||||
"""Repository for Conversation data access"""
|
||||
|
||||
def get_by_id(self, conversation_id: str) -> Optional[Conversation]:
|
||||
"""Get conversation by ID"""
|
||||
return self._session.query(Conversation).filter(
|
||||
Conversation.id == conversation_id
|
||||
).first()
|
||||
|
||||
def get_by_user(self, user_id: int, limit: int = 50) -> List[Conversation]:
|
||||
"""Get recent conversations for a user"""
|
||||
return self._session.query(Conversation).filter(
|
||||
Conversation.user_id == user_id
|
||||
).order_by(Conversation.updated_at.desc()).limit(limit).all()
|
||||
|
||||
def create(self, **kwargs) -> Conversation:
|
||||
"""Create a new conversation"""
|
||||
conversation = Conversation(**kwargs)
|
||||
self._session.add(conversation)
|
||||
return conversation
|
||||
|
||||
def update(self, conversation: Conversation) -> Conversation:
|
||||
"""Update an existing conversation"""
|
||||
self._session.add(conversation)
|
||||
return conversation
|
||||
|
||||
def delete(self, conversation_id: str) -> bool:
|
||||
"""Delete a conversation and its messages (cascade)"""
|
||||
conversation = self.get_by_id(conversation_id)
|
||||
if conversation:
|
||||
self._session.delete(conversation)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# Factory function for creating services
|
||||
def create_message_repository() -> MessageRepository:
|
||||
"""Factory for MessageRepository with its own session"""
|
||||
return MessageRepository(SessionLocal())
|
||||
|
||||
def create_provider_repository() -> ProviderRepository:
|
||||
"""Factory for ProviderRepository with its own session"""
|
||||
return ProviderRepository(SessionLocal())
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
"""Services module"""
|
||||
from luxx.services.llm_client import LLMClient
|
||||
from luxx.services.llm_response import ParsedDelta, LLMResponse
|
||||
from luxx.services.llm_response import ParsedDelta, Step, StepType
|
||||
from luxx.services.chat import ChatService, create_chat_service
|
||||
|
|
|
|||
|
|
@ -1,12 +1,6 @@
|
|||
"""AgenticLoop - Executes the Agentic Loop: LLM + Tools iteration.
|
||||
|
||||
The loop:
|
||||
1. Call LLM with messages and tools
|
||||
2. Check for tool calls in response
|
||||
3. Execute tools in parallel
|
||||
4. Add results to messages
|
||||
5. Repeat (max 10 iterations)
|
||||
6. Return final response
|
||||
This module follows the Single Responsibility Principle.
|
||||
"""
|
||||
import uuid
|
||||
import logging
|
||||
|
|
@ -14,21 +8,17 @@ from typing import List, Dict, AsyncGenerator
|
|||
|
||||
from luxx.tools.executor import ToolExecutor
|
||||
from luxx.services.llm_client import LLMClient
|
||||
from luxx.services.stream_context import StreamContext, _sse_event
|
||||
from luxx.services.process_result import ProcessResult
|
||||
from luxx.services.stream_context import StreamState, StreamRenderer, StepType
|
||||
from luxx.services.llm_response import ParsedDelta
|
||||
from luxx.services.events import sse_event
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum iterations to prevent infinite loops
|
||||
MAX_ITERATIONS = 10
|
||||
|
||||
|
||||
class AgenticLoop:
|
||||
"""Executes the Agentic Loop: LLM + Tools iteration.
|
||||
|
||||
Supports multiple LLM Providers, auto-adapts response format.
|
||||
"""
|
||||
"""Executes the agentic loop (LLM + Tools iteration)"""
|
||||
|
||||
def __init__(self, tool_executor: ToolExecutor):
|
||||
self.tool_executor = tool_executor
|
||||
|
|
@ -42,20 +32,15 @@ class AgenticLoop:
|
|||
temperature: float,
|
||||
max_tokens: int,
|
||||
thinking_enabled: bool,
|
||||
context: 'StreamContext',
|
||||
context: StreamState,
|
||||
tool_context: dict = None
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Execute the agentic loop.
|
||||
|
||||
Yields SSE events for each step.
|
||||
"""
|
||||
total_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
|
||||
|
||||
for iteration in range(MAX_ITERATIONS):
|
||||
context.reset()
|
||||
has_error = False
|
||||
|
||||
# Stream LLM response - now yields ParsedDelta directly
|
||||
async for delta in llm.stream_call(
|
||||
model=model,
|
||||
messages=messages,
|
||||
|
|
@ -64,65 +49,39 @@ class AgenticLoop:
|
|||
max_tokens=max_tokens,
|
||||
thinking_enabled=thinking_enabled
|
||||
):
|
||||
# Process parsed delta
|
||||
result = self._process_delta(delta, context, total_usage)
|
||||
|
||||
# Yield events
|
||||
for event in result.events:
|
||||
events = self._process_delta(delta, context, total_usage)
|
||||
for event in events:
|
||||
yield event
|
||||
|
||||
# Check for errors
|
||||
if result.has_error:
|
||||
if not delta.has_content() and not delta.is_complete:
|
||||
has_error = True
|
||||
break
|
||||
|
||||
# If error occurred, break the loop
|
||||
if has_error:
|
||||
break
|
||||
|
||||
# Finalize current step
|
||||
if delta.is_complete:
|
||||
for event in self._flush_remaining(context):
|
||||
yield event
|
||||
|
||||
context.finalize_step()
|
||||
|
||||
# Check for tool calls
|
||||
if context.tool_calls_list:
|
||||
# Execute tools and yield events
|
||||
for event in self._execute_tools(context, messages, tool_context):
|
||||
yield event
|
||||
continue
|
||||
|
||||
# No tools - complete
|
||||
for event in self._complete(context, total_usage):
|
||||
yield event
|
||||
return
|
||||
|
||||
# Max iterations exceeded or error occurred
|
||||
if not has_error:
|
||||
yield _sse_event("error", {"content": "Exceeded maximum tool call iterations"})
|
||||
yield sse_event("error", {"content": "Exceeded maximum tool call iterations"})
|
||||
|
||||
def _process_delta(
|
||||
self,
|
||||
delta: ParsedDelta,
|
||||
ctx: 'StreamContext',
|
||||
total_usage: dict
|
||||
) -> ProcessResult:
|
||||
"""Process ParsedDelta from adapter, return result with events and flags.
|
||||
def _process_delta(self, delta: ParsedDelta, ctx: StreamState, total_usage: dict) -> List[str]:
|
||||
"""Process a single delta from the LLM stream"""
|
||||
events = []
|
||||
|
||||
Args:
|
||||
delta: ParsedDelta from LLM adapter
|
||||
ctx: StreamContext for state management
|
||||
total_usage: Accumulated token usage
|
||||
|
||||
Returns:
|
||||
ProcessResult with events and flags
|
||||
"""
|
||||
result = ProcessResult()
|
||||
|
||||
# Check for error (empty delta with no content)
|
||||
if not delta.has_content() and not delta.is_complete:
|
||||
# Empty delta, possibly an error
|
||||
return result
|
||||
|
||||
# Update usage
|
||||
if delta.usage:
|
||||
total_usage.update({
|
||||
"prompt_tokens": delta.usage.get("prompt_tokens", 0),
|
||||
|
|
@ -130,73 +89,49 @@ class AgenticLoop:
|
|||
"total_tokens": delta.usage.get("total_tokens", 0)
|
||||
})
|
||||
|
||||
# Process thinking content (incremental)
|
||||
if delta.thinking:
|
||||
logger.debug(f"Processing thinking: {delta.thinking[:50]}...")
|
||||
ctx.full_thinking += delta.thinking # Accumulate incremental content
|
||||
if not ctx.current_step_id or ctx.current_step_type != "thinking":
|
||||
ctx.start_step("thinking")
|
||||
result.add_event(_sse_event("process_step", {
|
||||
"step": {
|
||||
"id": ctx.current_step_id,
|
||||
"index": ctx.current_step_idx,
|
||||
"type": "thinking",
|
||||
"content": ctx.full_thinking
|
||||
}
|
||||
}))
|
||||
result.set_content()
|
||||
if delta.content:
|
||||
result = ctx.process_content(delta.content)
|
||||
if result["should_emit"]:
|
||||
if result["thinking"]:
|
||||
ctx.full_thinking += result["thinking"]
|
||||
ctx.start_step(StepType.THINKING)
|
||||
events.append(StreamRenderer.render_thinking(ctx))
|
||||
|
||||
# Process text content (incremental)
|
||||
if delta.text:
|
||||
ctx.full_content += delta.text # Accumulate incremental content
|
||||
if not ctx.current_step_id or ctx.current_step_type != "text":
|
||||
ctx.start_step("text")
|
||||
result.add_event(_sse_event("process_step", {
|
||||
"step": {
|
||||
"id": ctx.current_step_id,
|
||||
"index": ctx.current_step_idx,
|
||||
"type": "text",
|
||||
"content": ctx.full_content
|
||||
}
|
||||
}))
|
||||
result.set_content()
|
||||
if result["text"]:
|
||||
ctx.full_content += result["text"]
|
||||
ctx.start_step(StepType.TEXT)
|
||||
events.append(StreamRenderer.render_text(ctx))
|
||||
|
||||
# Process tool calls
|
||||
if delta.tool_calls:
|
||||
for tc in delta.tool_calls:
|
||||
ctx.accumulate_tool_call(tc)
|
||||
result.set_tool_calls()
|
||||
ctx._thinking_buf = ""
|
||||
ctx._text_buf = ""
|
||||
|
||||
return result
|
||||
if delta.has_tool_call():
|
||||
ctx.accumulate_tool_call(delta.tool_call)
|
||||
|
||||
def _execute_tools(self, ctx: 'StreamContext', messages: list,
|
||||
tool_context: dict = None) -> List[str]:
|
||||
"""Execute tools and return list of events."""
|
||||
return events
|
||||
|
||||
def _execute_tools(self, ctx: StreamState, messages: list, tool_context: dict = None) -> List[str]:
|
||||
"""Execute tools and add results to messages"""
|
||||
events = []
|
||||
|
||||
# Emit tool call steps
|
||||
for event in ctx.emit_tool_calls():
|
||||
for event in StreamRenderer.render_tool_calls(ctx):
|
||||
events.append(event)
|
||||
|
||||
# Execute in parallel
|
||||
tool_results = self.tool_executor.process_tool_calls_parallel(
|
||||
ctx.tool_calls_list, tool_context or {}
|
||||
)
|
||||
|
||||
# Get tool call IDs for result linking
|
||||
tool_ids = [tc.get("id") for tc in ctx.tool_calls_list]
|
||||
tool_step_ids = [
|
||||
s["id"] for s in ctx.all_steps
|
||||
if s["type"] == "tool_call" and s.get("id_ref") in tool_ids
|
||||
s.id for s in ctx.all_steps
|
||||
if s.type == StepType.TOOL_CALL and s.id_ref in tool_ids
|
||||
]
|
||||
|
||||
# Emit tool result steps
|
||||
for i, (tr, tc) in enumerate(zip(tool_results, ctx.tool_calls_list)):
|
||||
ref_id = tool_step_ids[i] if i < len(tool_step_ids) else f"step-{len(ctx.all_steps) - len(tool_results) + i}"
|
||||
_, event = ctx.emit_tool_result(tr, ref_id)
|
||||
_, event = StreamRenderer.render_tool_result(ctx, tr, ref_id)
|
||||
events.append(event)
|
||||
|
||||
# Prepare for next iteration
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": ctx.full_content or "",
|
||||
|
|
@ -206,15 +141,31 @@ class AgenticLoop:
|
|||
|
||||
return events
|
||||
|
||||
def _complete(self, ctx: 'StreamContext', total_usage: dict) -> List[str]:
|
||||
"""Complete the loop and return list of events."""
|
||||
def _flush_remaining(self, ctx: StreamState) -> List[str]:
|
||||
"""Flush remaining buffers on complete"""
|
||||
events = []
|
||||
thinking, text = ctx.flush()
|
||||
if thinking:
|
||||
ctx.full_thinking += thinking
|
||||
ctx.start_step(StepType.THINKING)
|
||||
events.append(StreamRenderer.render_thinking(ctx))
|
||||
ctx.finalize_step()
|
||||
if text:
|
||||
ctx.full_content += text
|
||||
ctx.start_step(StepType.TEXT)
|
||||
events.append(StreamRenderer.render_text(ctx))
|
||||
ctx.finalize_step()
|
||||
return events
|
||||
|
||||
def _complete(self, ctx: StreamState, total_usage: dict) -> List[str]:
|
||||
"""Signal completion of the agentic loop"""
|
||||
token_count = total_usage.get("completion_tokens") or len(ctx.full_content) // 4
|
||||
msg_id = str(uuid.uuid4())
|
||||
logger.info(f"[TOKEN] usage={total_usage}, count={token_count}")
|
||||
|
||||
ctx.set_completion(msg_id, token_count, total_usage)
|
||||
|
||||
return [_sse_event("done", {
|
||||
return [sse_event("done", {
|
||||
"message_id": msg_id,
|
||||
"token_count": token_count,
|
||||
"usage": total_usage
|
||||
|
|
|
|||
|
|
@ -1,10 +1,13 @@
|
|||
"""Chat service module with Agentic Loop pattern.
|
||||
|
||||
This module provides the core chat service that orchestrates:
|
||||
- StreamContext: Manages streaming state transitions
|
||||
This module follows SOLID principles:
|
||||
- Single Responsibility: Each class has one job
|
||||
- Dependency Inversion: Depend on abstractions (repositories) not concretions
|
||||
|
||||
Components:
|
||||
- MessageBuilder: Constructs message lists
|
||||
- AgenticLoop: Executes the agentic loop (LLM + tools iteration)
|
||||
- ChatService: Core chat service facade
|
||||
- ChatService: Core chat service (orchestration only)
|
||||
- ProviderFactory: Creates LLM clients (Dependency Injection)
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
|
|
@ -15,9 +18,10 @@ from typing import List, Dict, Any, AsyncGenerator
|
|||
from luxx.tools.executor import ToolExecutor
|
||||
from luxx.tools.core import registry
|
||||
from luxx.services.llm_client import LLMClient
|
||||
from luxx.services.stream_context import StreamContext
|
||||
from luxx.services.stream_context import StreamState
|
||||
from luxx.services.agentic_loop import AgenticLoop
|
||||
from luxx.config import config
|
||||
from luxx.services.events import sse_event
|
||||
from luxx.repositories import UnitOfWork
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -79,65 +83,91 @@ class MessageBuilder:
|
|||
return content
|
||||
|
||||
|
||||
# ============== Factory Function ==============
|
||||
# ============== LLM Provider Factory ==============
|
||||
|
||||
def get_llm_client(conversation=None) -> tuple:
|
||||
"""Get LLM client based on conversation provider. Returns (client, max_tokens)"""
|
||||
from luxx.models import LLMProvider
|
||||
from luxx.database import SessionLocal
|
||||
class LLMProviderFactory:
|
||||
"""Factory for creating LLM clients - follows Dependency Injection
|
||||
|
||||
max_tokens = None
|
||||
This separates the creation of dependencies from their usage,
|
||||
following the Dependency Inversion Principle.
|
||||
"""
|
||||
|
||||
if conversation and conversation.provider_id:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
provider = db.query(LLMProvider).filter(
|
||||
LLMProvider.id == conversation.provider_id
|
||||
).first()
|
||||
if provider:
|
||||
max_tokens = provider.max_tokens
|
||||
@staticmethod
|
||||
def create_client(
|
||||
provider=None,
|
||||
api_key: str = None,
|
||||
api_url: str = None,
|
||||
model: str = None
|
||||
) -> tuple:
|
||||
"""Create LLM client from provider or direct parameters
|
||||
|
||||
Args:
|
||||
provider: LLMProvider model instance (optional)
|
||||
api_key: Direct API key (used if no provider)
|
||||
api_url: Direct API URL (used if no provider)
|
||||
model: Direct model name (used if no provider)
|
||||
|
||||
Returns:
|
||||
tuple: (LLMClient, max_tokens)
|
||||
"""
|
||||
if provider is not None:
|
||||
client = LLMClient(
|
||||
api_key=provider.api_key,
|
||||
api_url=provider.base_url,
|
||||
model=provider.default_model
|
||||
model=provider.default_model,
|
||||
provider_type=provider.provider_type
|
||||
)
|
||||
return client, max_tokens
|
||||
finally:
|
||||
db.close()
|
||||
return client, provider.max_tokens
|
||||
|
||||
return LLMClient(), max_tokens
|
||||
# Fallback to direct parameters
|
||||
client = LLMClient(
|
||||
api_key=api_key,
|
||||
api_url=api_url,
|
||||
model=model
|
||||
)
|
||||
return client, None
|
||||
|
||||
|
||||
# ============== Chat Service ==============
|
||||
|
||||
class ChatService:
|
||||
"""Core chat service with Agentic Loop support."""
|
||||
"""Core chat service with Agentic Loop support.
|
||||
|
||||
def __init__(self):
|
||||
self.tool_executor = ToolExecutor()
|
||||
self.agentic_loop = AgenticLoop(self.tool_executor)
|
||||
This class follows Single Responsibility - it orchestrates the chat flow
|
||||
but delegates data access to repositories and tool execution to executors.
|
||||
|
||||
Dependencies are injected via constructor for better testability.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tool_executor: ToolExecutor = None,
|
||||
agentic_loop: AgenticLoop = None,
|
||||
provider_factory: LLMProviderFactory = None
|
||||
):
|
||||
"""Initialize ChatService with injected dependencies
|
||||
|
||||
Args:
|
||||
tool_executor: Tool executor instance (creates default if None)
|
||||
agentic_loop: Agentic loop instance (creates default if None)
|
||||
provider_factory: LLM provider factory (uses default if None)
|
||||
"""
|
||||
self._tool_executor = tool_executor or ToolExecutor()
|
||||
self._agentic_loop = agentic_loop or AgenticLoop(self._tool_executor)
|
||||
self._provider_factory = provider_factory or LLMProviderFactory()
|
||||
|
||||
def build_messages(self, conversation, include_system: bool = True) -> List[Dict]:
|
||||
"""Build message list from conversation history."""
|
||||
from luxx.database import SessionLocal
|
||||
from luxx.models import Message
|
||||
|
||||
"""Build message list from conversation history using Repository"""
|
||||
messages = []
|
||||
|
||||
if include_system and conversation.system_prompt:
|
||||
messages.append({"role": "system", "content": conversation.system_prompt})
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
db_messages = db.query(Message).filter(
|
||||
Message.conversation_id == conversation.id
|
||||
).order_by(Message.created_at).all()
|
||||
|
||||
with UnitOfWork() as uow:
|
||||
db_messages = uow.messages.get_by_conversation(conversation.id)
|
||||
for msg in db_messages:
|
||||
content = MessageBuilder.extract_text(msg.content)
|
||||
messages.append({"role": msg.role, "content": content})
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return messages
|
||||
|
||||
|
|
@ -168,9 +198,11 @@ class ChatService:
|
|||
"content": json.dumps({"text": user_message, "attachments": []})
|
||||
})
|
||||
|
||||
# Get tools and LLM client
|
||||
# Get tools and LLM client via factory
|
||||
tools = self._get_tools(enabled_tools)
|
||||
llm, provider_max_tokens = get_llm_client(conversation)
|
||||
llm, provider_max_tokens = self._provider_factory.create_client(
|
||||
provider=conversation.provider
|
||||
)
|
||||
model = conversation.model or llm.default_model or "gpt-4"
|
||||
max_tokens = provider_max_tokens or 8192
|
||||
|
||||
|
|
@ -183,10 +215,10 @@ class ChatService:
|
|||
}
|
||||
|
||||
# Stream context
|
||||
ctx = StreamContext()
|
||||
ctx = StreamState()
|
||||
|
||||
# Execute agentic loop
|
||||
async for event in self.agentic_loop.execute(
|
||||
async for event in self._agentic_loop.execute(
|
||||
llm=llm,
|
||||
model=model,
|
||||
messages=messages,
|
||||
|
|
@ -199,22 +231,19 @@ class ChatService:
|
|||
):
|
||||
yield event
|
||||
|
||||
# Save message after successful completion (only if we have content)
|
||||
if ctx._last_message_id and (ctx.full_content or ctx.all_tool_calls):
|
||||
# Save message after successful completion
|
||||
if ctx._last_message_id and ctx.all_steps:
|
||||
self._save_message(
|
||||
conversation.id,
|
||||
ctx._last_message_id,
|
||||
ctx.full_content,
|
||||
ctx.all_tool_calls,
|
||||
ctx.all_tool_results,
|
||||
ctx.all_steps,
|
||||
ctx.get_steps_for_save(),
|
||||
ctx._last_token_count,
|
||||
ctx._last_usage
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Stream error: {e}\n{traceback.format_exc()}")
|
||||
yield _sse_event("error", {"content": str(e)})
|
||||
yield sse_event("error", {"content": str(e)})
|
||||
|
||||
async def non_stream_response(
|
||||
self,
|
||||
|
|
@ -232,7 +261,9 @@ class ChatService:
|
|||
})
|
||||
|
||||
tools = [] if not tools_enabled else None
|
||||
llm, max_tokens = get_llm_client(conversation)
|
||||
llm, max_tokens = self._provider_factory.create_client(
|
||||
provider=conversation.provider
|
||||
)
|
||||
model = conversation.model or llm.default_model or "gpt-4"
|
||||
|
||||
response = await llm.sync_call(
|
||||
|
|
@ -246,9 +277,9 @@ class ChatService:
|
|||
|
||||
return {
|
||||
"success": True,
|
||||
"content": response.content,
|
||||
"tool_calls": response.tool_calls,
|
||||
"usage": response.usage
|
||||
"content": response.get("content", ""),
|
||||
"tool_calls": response.get("tool_calls", []),
|
||||
"usage": response.get("usage", {})
|
||||
}
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
|
|
@ -262,41 +293,40 @@ class ChatService:
|
|||
logger.error(f"Non-stream error: {type(e).__name__}: {e}\n{traceback.format_exc()}")
|
||||
return {"success": False, "error": f"{type(e).__name__}: {str(e)}"}
|
||||
|
||||
def _save_message(self, conversation_id: str, msg_id: str, full_content: str,
|
||||
all_tool_calls: list, all_tool_results: list, all_steps: list,
|
||||
token_count: int = 0, usage: dict = None):
|
||||
"""Save assistant message to database."""
|
||||
from luxx.database import SessionLocal
|
||||
from luxx.models import Message
|
||||
def _save_message(
|
||||
self,
|
||||
conversation_id: str,
|
||||
msg_id: str,
|
||||
all_steps: list,
|
||||
token_count: int = 0,
|
||||
usage: dict = None
|
||||
):
|
||||
"""Save assistant message to database using Repository"""
|
||||
content_json = {"steps": all_steps}
|
||||
|
||||
content_json = {"text": full_content, "steps": all_steps}
|
||||
if all_tool_calls:
|
||||
content_json["tool_calls"] = all_tool_calls
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
msg = Message(
|
||||
id=msg_id,
|
||||
with UnitOfWork() as uow:
|
||||
uow.messages.create(
|
||||
msg_id=msg_id,
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content=json.dumps(content_json, ensure_ascii=False),
|
||||
content=content_json,
|
||||
token_count=token_count,
|
||||
usage=json.dumps(usage) if usage else None
|
||||
usage=usage
|
||||
)
|
||||
db.add(msg)
|
||||
db.commit()
|
||||
except Exception:
|
||||
db.rollback()
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
uow.commit()
|
||||
|
||||
|
||||
def _sse_event(event: str, data: dict) -> str:
|
||||
"""Format a Server-Sent Event string."""
|
||||
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||
# ============== Factory Function ==============
|
||||
|
||||
def create_chat_service(
|
||||
tool_executor: ToolExecutor = None,
|
||||
agentic_loop: AgenticLoop = None
|
||||
) -> ChatService:
|
||||
"""Factory function to create ChatService instances"""
|
||||
return ChatService(
|
||||
tool_executor=tool_executor,
|
||||
agentic_loop=agentic_loop
|
||||
)
|
||||
|
||||
|
||||
|
||||
def create_chat_service() -> ChatService:
|
||||
"""Factory function to create ChatService instances."""
|
||||
return ChatService()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,31 @@
|
|||
"""SSE (Server-Sent Events) utilities
|
||||
|
||||
This module provides SSE formatting functions to avoid circular imports
|
||||
between stream_context.py and agentic_loop.py.
|
||||
"""
|
||||
import json
|
||||
|
||||
|
||||
def sse_event(event: str, data: dict) -> str:
|
||||
"""Format a Server-Sent Event string
|
||||
|
||||
Args:
|
||||
event: Event type name
|
||||
data: Event data dictionary
|
||||
|
||||
Returns:
|
||||
Formatted SSE string
|
||||
"""
|
||||
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||
|
||||
|
||||
def sse_comment(message: str) -> str:
|
||||
"""Format a SSE comment (for keep-alive or debugging)
|
||||
|
||||
Args:
|
||||
message: Comment message
|
||||
|
||||
Returns:
|
||||
Formatted SSE comment string
|
||||
"""
|
||||
return f": {message}\n\n"
|
||||
|
|
@ -2,12 +2,13 @@
|
|||
|
||||
Supports Anthropic Claude API streaming and non-streaming responses.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Any, AsyncGenerator
|
||||
|
||||
from .base import ProviderAdapter
|
||||
from ..llm_response import ParsedDelta, LLMResponse
|
||||
from ..llm_response import ParsedDelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -233,12 +234,12 @@ class AnthropicAdapter(ProviderAdapter):
|
|||
# Tool use block start
|
||||
tool_index = chunk.get("index", 0)
|
||||
tool_name = block.get("name", "")
|
||||
result.tool_calls = [{
|
||||
result.tool_call = {
|
||||
"index": tool_index,
|
||||
"id": "",
|
||||
"type": "function",
|
||||
"function": {"name": tool_name, "arguments": ""}
|
||||
}]
|
||||
}
|
||||
|
||||
elif block_type == self.SUBTYPE_TEXT:
|
||||
# Text block start - nothing to output yet
|
||||
|
|
@ -262,13 +263,11 @@ class AnthropicAdapter(ProviderAdapter):
|
|||
elif delta_type == self.DELTA_INPUT_JSON:
|
||||
# Tool arguments delta (incremental)
|
||||
partial_json = delta.get("partial_json", "")
|
||||
# For tool calls, we need to update the arguments
|
||||
# This is handled by the consumer (AgenticLoop)
|
||||
if partial_json:
|
||||
result.tool_calls = [{
|
||||
result.tool_call = {
|
||||
"index": 0,
|
||||
"function": {"arguments": partial_json}
|
||||
}]
|
||||
}
|
||||
|
||||
elif chunk_type == self.BLOCK_CONTENT_BLOCK_STOP:
|
||||
# Content block stop
|
||||
|
|
@ -297,7 +296,7 @@ class AnthropicAdapter(ProviderAdapter):
|
|||
if result.has_content() or result.is_complete:
|
||||
yield result
|
||||
|
||||
def parse_response(self, data: Dict[str, Any]) -> LLMResponse:
|
||||
def parse_response(self, data: Dict[str, Any]) -> Dict:
|
||||
"""Parse non-streaming response"""
|
||||
content = data.get("content", [])
|
||||
thinking = ""
|
||||
|
|
@ -321,16 +320,19 @@ class AnthropicAdapter(ProviderAdapter):
|
|||
|
||||
usage = data.get("usage", {})
|
||||
|
||||
return LLMResponse(
|
||||
content=text_content,
|
||||
thinking=thinking,
|
||||
tool_calls=tool_calls,
|
||||
usage={
|
||||
return {
|
||||
"content": text_content,
|
||||
"thinking": thinking,
|
||||
"tool_calls": tool_calls,
|
||||
"usage": {
|
||||
"prompt_tokens": usage.get("input_tokens", 0),
|
||||
"completion_tokens": usage.get("output_tokens", 0),
|
||||
"total_tokens": usage.get("input_tokens", 0) + usage.get("output_tokens", 0)
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
def supports_thinking(self) -> bool:
|
||||
return True
|
||||
|
||||
def supports_tools(self) -> bool:
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -1,200 +1,86 @@
|
|||
"""OpenAI Adapter - OpenAI-compatible API adapter
|
||||
"""OpenAI Adapter - OpenAI/DeepSeek/GLM/MiniMax compatible API adapter"""
|
||||
|
||||
Supports OpenAI, DeepSeek, GLM and other OpenAI-compatible APIs.
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Any, AsyncGenerator, Optional
|
||||
from typing import Dict, List, Any, AsyncGenerator
|
||||
|
||||
from .base import ProviderAdapter
|
||||
from ..llm_response import ParsedDelta, LLMResponse
|
||||
from ..llm_response import ParsedDelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIAdapter(ProviderAdapter):
|
||||
"""OpenAI-compatible API adapter
|
||||
"""OpenAI-compatible API adapter"""
|
||||
|
||||
Pure parsing adapter - no internal state management.
|
||||
Each parse_stream_chunk call returns incremental content.
|
||||
Accumulation is handled by the consumer (AgenticLoop).
|
||||
"""
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
def provider_type(self) -> str:
|
||||
return "openai"
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def build_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Dict[str, Any]],
|
||||
tools: List[Dict[str, Any]] = None,
|
||||
**kwargs
|
||||
) -> tuple[Dict[str, Any], Dict[str, str]]:
|
||||
"""Build OpenAI-format request"""
|
||||
def build_request(self, model: str, messages: List[Dict], tools=None, **kwargs) -> tuple:
|
||||
api_key = kwargs.get("api_key", "")
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}"
|
||||
}
|
||||
|
||||
body = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"stream": kwargs.get("stream", True)
|
||||
}
|
||||
|
||||
# Optional parameters
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
|
||||
body = {"model": model, "messages": messages, "stream": kwargs.get("stream", True)}
|
||||
if "temperature" in kwargs:
|
||||
body["temperature"] = kwargs["temperature"]
|
||||
if "max_tokens" in kwargs:
|
||||
body["max_tokens"] = kwargs["max_tokens"]
|
||||
if "top_p" in kwargs:
|
||||
body["top_p"] = kwargs["top_p"]
|
||||
if "frequency_penalty" in kwargs:
|
||||
body["frequency_penalty"] = kwargs["frequency_penalty"]
|
||||
if "presence_penalty" in kwargs:
|
||||
body["presence_penalty"] = kwargs["presence_penalty"]
|
||||
if "stop" in kwargs:
|
||||
body["stop"] = kwargs["stop"]
|
||||
if tools:
|
||||
body["tools"] = tools
|
||||
if kwargs.get("thinking_enabled"):
|
||||
body["thinking_enabled"] = True
|
||||
|
||||
body["tool_choice"] = "auto"
|
||||
return body, headers
|
||||
|
||||
def reset(self):
|
||||
"""No-op for pure parsing adapter"""
|
||||
pass
|
||||
|
||||
async def parse_stream_chunk(
|
||||
self,
|
||||
raw_chunk: str
|
||||
) -> AsyncGenerator[ParsedDelta, None]:
|
||||
"""Parse OpenAI-format SSE stream
|
||||
async def parse_stream_chunk(self, raw_chunk: str) -> AsyncGenerator[ParsedDelta, None]:
|
||||
"""Parse OpenAI/MiniMax format. Returns raw content for accumulation."""
|
||||
if not raw_chunk or not raw_chunk.strip():
|
||||
return
|
||||
|
||||
Returns incremental content - no accumulation.
|
||||
"""
|
||||
# Parse SSE line
|
||||
event_type, data_str = self._parse_sse_line(raw_chunk)
|
||||
chunk_str = raw_chunk.strip()
|
||||
if chunk_str.startswith("data: "):
|
||||
chunk_str = chunk_str[6:]
|
||||
elif chunk_str.startswith("data:"):
|
||||
chunk_str = chunk_str[5:]
|
||||
|
||||
if not data_str or data_str == "[DONE]":
|
||||
if data_str == "[DONE]":
|
||||
if chunk_str.strip() == "[DONE]":
|
||||
yield ParsedDelta(is_complete=True)
|
||||
return
|
||||
|
||||
try:
|
||||
chunk = json.loads(data_str)
|
||||
chunk = json.loads(chunk_str)
|
||||
except json.JSONDecodeError:
|
||||
return
|
||||
|
||||
# Handle errors
|
||||
if event_type == "error" or "error" in chunk:
|
||||
yield ParsedDelta()
|
||||
choices = chunk.get("choices", [])
|
||||
if not choices:
|
||||
return
|
||||
|
||||
# Extract usage
|
||||
usage = chunk.get("usage", {})
|
||||
delta = choices[0].get("delta", {})
|
||||
finish_reason = choices[0].get("finish_reason")
|
||||
content = delta.get("content", "")
|
||||
|
||||
# Parse choices
|
||||
for choice in chunk.get("choices", []):
|
||||
delta = choice.get("delta", {})
|
||||
content = delta.get("content") or ""
|
||||
|
||||
# Extract thinking tags if present
|
||||
thinking, clean_text = self._extract_tags(content)
|
||||
|
||||
# Tool calls
|
||||
tool_calls = delta.get("tool_calls", [])
|
||||
|
||||
# Check if this is the final delta
|
||||
is_complete = bool(choice.get("finish_reason"))
|
||||
|
||||
if thinking or clean_text or tool_calls or is_complete or usage:
|
||||
yield ParsedDelta(
|
||||
thinking=thinking,
|
||||
text=clean_text,
|
||||
tool_calls=tool_calls if tool_calls else [],
|
||||
is_complete=is_complete,
|
||||
usage=usage if usage else {}
|
||||
)
|
||||
|
||||
def parse_response(self, data: Dict[str, Any]) -> LLMResponse:
|
||||
"""Parse non-streaming response"""
|
||||
choice = data.get("choices", [{}])[0]
|
||||
message = choice.get("message", {})
|
||||
|
||||
content = message.get("content", "") or ""
|
||||
thinking, clean_content = self._extract_tags(content)
|
||||
if not thinking:
|
||||
thinking = message.get("reasoning_content") or ""
|
||||
|
||||
tool_calls = message.get("tool_calls", [])
|
||||
|
||||
usage = data.get("usage", {})
|
||||
|
||||
return LLMResponse(
|
||||
content=clean_content,
|
||||
thinking=thinking,
|
||||
tool_calls=tool_calls,
|
||||
usage=usage
|
||||
)
|
||||
|
||||
def _parse_sse_line(self, line: str) -> tuple:
|
||||
"""Parse a single SSE line, return (event_type, data)"""
|
||||
if line.startswith("event:"):
|
||||
return line[6:].strip(), None
|
||||
elif line.startswith("data:"):
|
||||
return "", line[5:].strip()
|
||||
return "", None
|
||||
|
||||
def _extract_tags(self, content: str) -> tuple:
|
||||
"""Extract thinking tags and return (thinking, clean_text)
|
||||
|
||||
Handles thinking tags that may be split across chunks:
|
||||
- First </think> in content closes any thinking block
|
||||
- Everything before first </think> is thinking
|
||||
- Everything after first </think> is clean text
|
||||
"""
|
||||
if not content:
|
||||
return "", ""
|
||||
if finish_reason is not None:
|
||||
yield ParsedDelta(is_complete=True)
|
||||
return
|
||||
|
||||
content_lower = content.lower()
|
||||
yield ParsedDelta(content=content)
|
||||
|
||||
# Find first </think> (marks end of thinking block)
|
||||
end_idx = content_lower.find("</think>")
|
||||
def parse_response(self, data: Dict) -> Dict:
|
||||
"""Parse non-streaming response."""
|
||||
choices = data.get("choices", [])
|
||||
if not choices:
|
||||
return {"content": "", "tool_calls": [], "usage": {}}
|
||||
message = choices[0].get("message", {})
|
||||
content = message.get("content", "")
|
||||
tool_calls = message.get("tool_calls", [])
|
||||
usage = data.get("usage", {})
|
||||
return {"content": content, "tool_calls": tool_calls, "usage": usage}
|
||||
|
||||
if end_idx != -1:
|
||||
# Found end tag - split at this point
|
||||
thinking_content = content[:end_idx].strip()
|
||||
# Find if there's also a start tag before this
|
||||
start_idx = content_lower.rfind("<think>", 0, end_idx)
|
||||
|
||||
if start_idx != -1:
|
||||
# There's a complete thinking block
|
||||
thinking = content[start_idx + 7:end_idx]
|
||||
clean = content[end_idx + 9:]
|
||||
else:
|
||||
# No start tag - this is the end of a split thinking block
|
||||
# Everything before </think> was thinking
|
||||
thinking = content[:end_idx]
|
||||
clean = content[end_idx + 9:]
|
||||
|
||||
return thinking, clean
|
||||
|
||||
# No end tag found
|
||||
# Check if there's a start tag
|
||||
start_idx = content_lower.find("<think>")
|
||||
|
||||
if start_idx != -1:
|
||||
# Has start tag but no end - all content after start is thinking
|
||||
thinking = content[start_idx + 7:]
|
||||
return thinking, ""
|
||||
else:
|
||||
# No tags at all - everything is clean
|
||||
return "", content
|
||||
def supports_tools(self) -> bool:
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -17,12 +17,14 @@ Usage:
|
|||
|
||||
# Streaming call
|
||||
async for delta in client.stream_call(model, messages, tools=tools):
|
||||
print(delta.text, delta.thinking, delta.tool_calls)
|
||||
print(delta.text, delta.thinking, delta.tool_call)
|
||||
|
||||
Extending Providers:
|
||||
LLMClient.register_adapter("my_provider", MyAdapter)
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
import traceback
|
||||
from typing import Dict, List, Any, Optional, AsyncGenerator
|
||||
from typing import Dict, List, Any, Optional, AsyncGenerator, Type
|
||||
|
||||
import httpx
|
||||
|
||||
|
|
@ -32,7 +34,7 @@ from luxx.services.llm_adapters import (
|
|||
OpenAIAdapter,
|
||||
AnthropicAdapter,
|
||||
)
|
||||
from luxx.services.llm_response import ParsedDelta, LLMResponse
|
||||
from luxx.services.llm_response import ParsedDelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -42,6 +44,9 @@ class LLMClient:
|
|||
|
||||
Uses adapter pattern to support different API formats, auto-detects or manually specifies Provider type.
|
||||
|
||||
Supports plugin registration for extending providers:
|
||||
LLMClient.register_adapter("my_provider", MyAdapter)
|
||||
|
||||
Attributes:
|
||||
api_key: API key
|
||||
api_url: API base URL
|
||||
|
|
@ -50,8 +55,8 @@ class LLMClient:
|
|||
adapter: Current adapter instance
|
||||
"""
|
||||
|
||||
# Provider type to adapter class mapping
|
||||
PROVIDER_ADAPTERS: Dict[str, type] = {
|
||||
# Plugin registry for provider adapters (Open for Extension, Closed for Modification)
|
||||
_adapter_registry: Dict[str, type] = {
|
||||
# OpenAI-compatible formats
|
||||
"openai": OpenAIAdapter,
|
||||
"deepseek": OpenAIAdapter,
|
||||
|
|
@ -63,13 +68,40 @@ class LLMClient:
|
|||
}
|
||||
|
||||
# URL keywords for provider detection
|
||||
PROVIDER_URL_KEYWORDS: Dict[str, List[str]] = {
|
||||
_url_keywords: Dict[str, List[str]] = {
|
||||
"anthropic": ["anthropic", "claude"],
|
||||
"deepseek": ["deepseek"],
|
||||
"glm": ["glm", "zhipu", "chatglm"],
|
||||
"openai": ["openai"],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def register_adapter(cls, provider_type: str, adapter_class: Type[ProviderAdapter]) -> None:
|
||||
"""Register a new adapter for a provider type
|
||||
|
||||
This follows the Open-Closed Principle (OCP) - open for extension, closed for modification.
|
||||
|
||||
Args:
|
||||
provider_type: Provider type identifier (e.g., "ollama", "groq")
|
||||
adapter_class: Adapter class (must inherit from ProviderAdapter)
|
||||
|
||||
Example:
|
||||
class OllamaAdapter(ProviderAdapter):
|
||||
...
|
||||
|
||||
LLMClient.register_adapter("ollama", OllamaAdapter)
|
||||
"""
|
||||
if not issubclass(adapter_class, ProviderAdapter):
|
||||
raise TypeError(f"{adapter_class.__name__} must inherit from ProviderAdapter")
|
||||
|
||||
cls._adapter_registry[provider_type] = adapter_class
|
||||
logger.info(f"Registered adapter '{adapter_class.__name__}' for provider '{provider_type}'")
|
||||
|
||||
@classmethod
|
||||
def list_providers(cls) -> List[str]:
|
||||
"""List all registered provider types"""
|
||||
return list(cls._adapter_registry.keys())
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str = None,
|
||||
|
|
@ -110,7 +142,7 @@ class LLMClient:
|
|||
url = url or self.api_url
|
||||
url_lower = url.lower()
|
||||
|
||||
for provider, keywords in self.PROVIDER_URL_KEYWORDS.items():
|
||||
for provider, keywords in self._url_keywords.items():
|
||||
for keyword in keywords:
|
||||
if keyword in url_lower:
|
||||
logger.debug(f"Detected provider '{provider}' from URL: {url}")
|
||||
|
|
@ -125,7 +157,7 @@ class LLMClient:
|
|||
Returns:
|
||||
ProviderAdapter subclass instance
|
||||
"""
|
||||
adapter_class = self.PROVIDER_ADAPTERS.get(
|
||||
adapter_class = self._adapter_registry.get(
|
||||
self.provider_type,
|
||||
OpenAIAdapter
|
||||
)
|
||||
|
|
@ -160,7 +192,7 @@ class LLMClient:
|
|||
messages: List[Dict[str, Any]],
|
||||
tools: List[Dict[str, Any]] = None,
|
||||
**kwargs
|
||||
) -> LLMResponse:
|
||||
) -> Dict:
|
||||
"""Synchronous call to LLM (non-streaming)
|
||||
|
||||
Args:
|
||||
|
|
@ -170,7 +202,7 @@ class LLMClient:
|
|||
**kwargs: Other parameters (temperature, max_tokens, thinking_enabled, etc.)
|
||||
|
||||
Returns:
|
||||
LLMResponse object
|
||||
Dict with keys: content, thinking, tool_calls, usage
|
||||
"""
|
||||
import asyncio
|
||||
return asyncio.get_event_loop().run_until_complete(
|
||||
|
|
@ -183,7 +215,7 @@ class LLMClient:
|
|||
messages: List[Dict[str, Any]],
|
||||
tools: List[Dict[str, Any]] = None,
|
||||
**kwargs
|
||||
) -> LLMResponse:
|
||||
) -> Dict:
|
||||
"""Internal async sync call"""
|
||||
model = model or self.default_model
|
||||
kwargs["api_key"] = self.api_key
|
||||
|
|
@ -259,8 +291,13 @@ class LLMClient:
|
|||
response.raise_for_status()
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
if line.strip():
|
||||
async for delta in self.adapter.parse_stream_chunk(line):
|
||||
# MiniMax may send multiple SSE events concatenated on one line
|
||||
# Format: data: {...}\ndata: {...}\n
|
||||
parts = line.split("data: ")
|
||||
for part in parts:
|
||||
part = part.strip()
|
||||
if part and part != "[DONE]" and part.startswith("{"):
|
||||
async for delta in self.adapter.parse_stream_chunk("data: " + part):
|
||||
yield delta
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
|
|
|
|||
|
|
@ -1,65 +1,60 @@
|
|||
"""LLM Response - Unified message classes for LLM communication
|
||||
"""LLM Response - Unified message classes for LLM communication"""
|
||||
|
||||
This module provides unified data classes for message passing throughout the LLM pipeline.
|
||||
"""
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
class StepType:
|
||||
"""Step type constants"""
|
||||
THINKING = "thinking"
|
||||
TEXT = "text"
|
||||
TOOL_CALL = "tool_call"
|
||||
TOOL_RESULT = "tool_result"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Step:
|
||||
"""Single step - used for storage and transport"""
|
||||
id: str
|
||||
index: int
|
||||
type: str
|
||||
content: str = ""
|
||||
name: str = ""
|
||||
arguments: str = ""
|
||||
id_ref: str = ""
|
||||
success: bool = True
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
return {
|
||||
"id": self.id,
|
||||
"index": self.index,
|
||||
"type": self.type,
|
||||
"content": self.content,
|
||||
"name": self.name,
|
||||
"arguments": self.arguments,
|
||||
"id_ref": self.id_ref,
|
||||
"success": self.success
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParsedDelta:
|
||||
"""Streaming response delta
|
||||
|
||||
Represents a single unit of streaming response data.
|
||||
Used for streaming responses where content is accumulated incrementally.
|
||||
|
||||
Attributes:
|
||||
thinking: Accumulated thinking/reasoning content
|
||||
text: Accumulated text content
|
||||
tool_calls: List of tool call requests
|
||||
is_complete: Whether this is the final delta
|
||||
usage: Token usage statistics
|
||||
"""
|
||||
"""LLM streaming response delta"""
|
||||
content: str = ""
|
||||
thinking: str = ""
|
||||
text: str = ""
|
||||
tool_calls: List[Dict] = field(default_factory=list)
|
||||
is_complete: bool = False
|
||||
tool_call: Optional[Dict] = None
|
||||
usage: Dict[str, int] = field(default_factory=dict)
|
||||
is_complete: bool = False
|
||||
|
||||
def has_thinking(self) -> bool:
|
||||
"""Check if there's thinking content"""
|
||||
return bool(self.thinking)
|
||||
|
||||
def has_text(self) -> bool:
|
||||
"""Check if there's text content"""
|
||||
return bool(self.text)
|
||||
|
||||
def has_tool_calls(self) -> bool:
|
||||
"""Check if there are tool calls"""
|
||||
return bool(self.tool_calls)
|
||||
def has_tool_call(self) -> bool:
|
||||
return self.tool_call is not None
|
||||
|
||||
def has_content(self) -> bool:
|
||||
"""Check if there's any content"""
|
||||
return self.has_thinking() or self.has_text() or self.has_tool_calls()
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
"""Complete LLM response
|
||||
|
||||
Represents a complete non-streaming response.
|
||||
|
||||
Attributes:
|
||||
content: Final text content
|
||||
thinking: Final thinking content (if any)
|
||||
tool_calls: List of tool calls (if any)
|
||||
usage: Token usage statistics
|
||||
"""
|
||||
content: str = ""
|
||||
thinking: str = ""
|
||||
tool_calls: List[Dict] = field(default_factory=list)
|
||||
usage: Dict[str, int] = field(default=dict)
|
||||
|
||||
def has_tool_calls(self) -> bool:
|
||||
"""Check if there are tool calls"""
|
||||
return bool(self.tool_calls)
|
||||
return bool(self.content) or self.has_thinking() or self.has_text() or self.has_tool_call()
|
||||
|
|
|
|||
|
|
@ -1,37 +0,0 @@
|
|||
"""ProcessResult - Result of processing an SSE line."""
|
||||
|
||||
|
||||
class ProcessResult:
|
||||
"""Result of processing an SSE line.
|
||||
|
||||
Attributes:
|
||||
events: List of SSE event strings to yield
|
||||
has_error: Whether an error occurred
|
||||
error_content: Error message if any
|
||||
has_content: Whether content was received
|
||||
has_tool_calls: Whether tool calls were received
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.events: list = []
|
||||
self.has_error: bool = False
|
||||
self.error_content: str = ""
|
||||
self.has_content: bool = False
|
||||
self.has_tool_calls: bool = False
|
||||
|
||||
def add_event(self, event: str):
|
||||
"""Add an event to the result."""
|
||||
self.events.append(event)
|
||||
|
||||
def set_error(self, content: str):
|
||||
"""Set error state."""
|
||||
self.has_error = True
|
||||
self.error_content = content
|
||||
|
||||
def set_content(self):
|
||||
"""Mark that content was received."""
|
||||
self.has_content = True
|
||||
|
||||
def set_tool_calls(self):
|
||||
"""Mark that tool calls were received."""
|
||||
self.has_tool_calls = True
|
||||
|
|
@ -1,51 +1,144 @@
|
|||
"""StreamContext - Manages streaming state transitions during LLM response.
|
||||
"""Stream Context - Manages streaming state and content accumulation
|
||||
|
||||
Tracks steps in order:
|
||||
- thinking: Model reasoning content
|
||||
- text: Model response text
|
||||
- tool_call: Tool invocation request
|
||||
- tool_result: Tool execution result
|
||||
|
||||
Each step has unique id and index for frontend rendering.
|
||||
This module follows the Composition over Inheritance principle initially,
|
||||
but StreamContext inherits from StreamState for simplicity.
|
||||
The rendering logic is delegated to a separate StreamRenderer.
|
||||
"""
|
||||
import json
|
||||
from typing import List, Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Dict, Any, Optional
|
||||
from enum import Enum
|
||||
|
||||
from luxx.services.events import sse_event
|
||||
|
||||
|
||||
def _sse_event(event: str, data: dict) -> str:
|
||||
"""Format a Server-Sent Event string."""
|
||||
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||
class StepType(str, Enum):
|
||||
"""Step type enumeration"""
|
||||
THINKING = "thinking"
|
||||
TEXT = "text"
|
||||
TOOL_CALL = "tool_call"
|
||||
TOOL_RESULT = "tool_result"
|
||||
|
||||
|
||||
class StreamContext:
|
||||
"""Manages streaming state transitions during LLM response."""
|
||||
THINK_START = "<think>"
|
||||
THINK_END = "</think>"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Step:
|
||||
"""Represents a single step in the response process"""
|
||||
id: str
|
||||
index: int
|
||||
type: str
|
||||
content: str = ""
|
||||
name: str = ""
|
||||
arguments: str = ""
|
||||
id_ref: str = ""
|
||||
success: bool = True
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
result = {
|
||||
"id": self.id,
|
||||
"index": self.index,
|
||||
"type": self.type,
|
||||
}
|
||||
if self.content:
|
||||
result["content"] = self.content
|
||||
if self.name:
|
||||
result["name"] = self.name
|
||||
if self.arguments:
|
||||
result["arguments"] = self.arguments
|
||||
if self.id_ref:
|
||||
result["id_ref"] = self.id_ref
|
||||
if self.type == StepType.TOOL_RESULT:
|
||||
result["success"] = self.success
|
||||
return result
|
||||
|
||||
|
||||
class StreamState:
|
||||
"""Pure state management for streaming
|
||||
|
||||
This class maintains all state but delegates rendering to StreamRenderer.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.step_index = 0
|
||||
self.current_step_id = None
|
||||
self.current_step_idx = None
|
||||
self.current_step_type = None
|
||||
self.full_content = ""
|
||||
self.full_thinking = ""
|
||||
self.all_steps = []
|
||||
self.all_tool_calls = []
|
||||
self.all_tool_results = []
|
||||
self.tool_calls_list = []
|
||||
self._last_message_id = None
|
||||
self._last_token_count = 0
|
||||
self._last_usage = None
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""Reset state for new iteration."""
|
||||
self.current_step_id = None
|
||||
self.current_step_idx = None
|
||||
self.current_step_type = None
|
||||
"""Reset all state for a new stream"""
|
||||
self.step_index = 0
|
||||
self.current_step_id: Optional[str] = None
|
||||
self.current_step_idx: Optional[int] = None
|
||||
self.current_step_type: Optional[str] = None
|
||||
self.full_content = ""
|
||||
self.full_thinking = ""
|
||||
self.tool_calls_list = []
|
||||
self.all_steps: List[Step] = []
|
||||
self.all_tool_results: List[Dict] = []
|
||||
self.tool_calls_list: List[Dict] = []
|
||||
self._last_message_id: Optional[str] = None
|
||||
self._last_token_count = 0
|
||||
self._last_usage: Optional[Dict] = None
|
||||
self._in_thinking = False
|
||||
self._thinking_buf = ""
|
||||
self._text_buf = ""
|
||||
|
||||
def process_content(self, content: str) -> Dict:
|
||||
"""Process raw content, handling thinking tags."""
|
||||
if not content:
|
||||
return {"thinking": "", "text": "", "should_emit": False, "thinking_only": False}
|
||||
|
||||
thinking = ""
|
||||
text = ""
|
||||
should_emit = False
|
||||
thinking_only = False
|
||||
|
||||
if THINK_START in content and not self._in_thinking:
|
||||
self._in_thinking = True
|
||||
idx = content.find(THINK_START) + len(THINK_START)
|
||||
content = content[idx:]
|
||||
|
||||
if THINK_END in content:
|
||||
idx = content.find(THINK_END)
|
||||
thinking_content = content[:idx]
|
||||
self._thinking_buf += thinking_content
|
||||
content = content[idx + len(THINK_END):]
|
||||
|
||||
# Remove all remaining thinking tags from text (MiniMax format)
|
||||
while THINK_END in content:
|
||||
second_idx = content.find(THINK_END)
|
||||
text_content = content[:second_idx]
|
||||
self._text_buf += text_content
|
||||
content = content[second_idx + len(THINK_END):]
|
||||
|
||||
self._in_thinking = False
|
||||
should_emit = True
|
||||
thinking_only = not bool(self._text_buf)
|
||||
|
||||
if self._in_thinking:
|
||||
self._thinking_buf += content
|
||||
else:
|
||||
self._text_buf += content
|
||||
|
||||
if should_emit:
|
||||
thinking = self._thinking_buf
|
||||
text = self._text_buf
|
||||
|
||||
return {
|
||||
"thinking": thinking,
|
||||
"text": text,
|
||||
"should_emit": should_emit,
|
||||
"thinking_only": thinking_only
|
||||
}
|
||||
|
||||
def flush(self) -> tuple:
|
||||
"""Flush remaining buffers and return content"""
|
||||
thinking = self._thinking_buf
|
||||
text = self._text_buf
|
||||
self._thinking_buf = ""
|
||||
self._text_buf = ""
|
||||
return thinking, text
|
||||
|
||||
def start_step(self, step_type: str) -> str:
|
||||
"""Start a new step with unique ID."""
|
||||
"""Start a new step and return its ID"""
|
||||
self.current_step_idx = self.step_index
|
||||
self.current_step_id = f"step-{self.step_index}"
|
||||
self.current_step_type = step_type
|
||||
|
|
@ -53,20 +146,20 @@ class StreamContext:
|
|||
return self.current_step_id
|
||||
|
||||
def finalize_step(self):
|
||||
"""Save current step to all_steps."""
|
||||
"""Finalize the current step and add to all_steps"""
|
||||
if self.current_step_id is None:
|
||||
return
|
||||
|
||||
content = self.full_content if self.current_step_type == "text" else self.full_thinking
|
||||
self.all_steps.append({
|
||||
"id": self.current_step_id,
|
||||
"index": self.current_step_idx,
|
||||
"type": self.current_step_type,
|
||||
"content": content
|
||||
})
|
||||
content = self.full_content if self.current_step_type == StepType.TEXT else self.full_thinking
|
||||
step = Step(
|
||||
id=self.current_step_id,
|
||||
index=self.current_step_idx,
|
||||
type=self.current_step_type,
|
||||
content=content
|
||||
)
|
||||
self.all_steps.append(step)
|
||||
|
||||
def accumulate_tool_call(self, tc_delta: Dict):
|
||||
"""Accumulate tool call delta."""
|
||||
"""Accumulate tool call delta"""
|
||||
idx = tc_delta.get("index", 0)
|
||||
if idx >= len(self.tool_calls_list):
|
||||
self.tool_calls_list.append({
|
||||
|
|
@ -74,41 +167,63 @@ class StreamContext:
|
|||
"type": "function",
|
||||
"function": {"name": "", "arguments": ""}
|
||||
})
|
||||
|
||||
func = tc_delta.get("function", {})
|
||||
if func.get("name"):
|
||||
self.tool_calls_list[idx]["function"]["name"] += func["name"]
|
||||
if func.get("arguments"):
|
||||
self.tool_calls_list[idx]["function"]["arguments"] += func["arguments"]
|
||||
|
||||
def emit_tool_calls(self) -> List[str]:
|
||||
"""Emit tool call steps, return SSE events."""
|
||||
def add_tool_result(self, result: Dict):
|
||||
"""Add a tool result to history"""
|
||||
self.all_tool_results.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": result.get("tool_call_id", ""),
|
||||
"content": result.get("content", "")
|
||||
})
|
||||
|
||||
def set_completion(self, msg_id: str, token_count: int, usage: dict):
|
||||
"""Set completion metadata"""
|
||||
self._last_message_id = msg_id
|
||||
self._last_token_count = token_count
|
||||
self._last_usage = usage
|
||||
|
||||
def get_steps_for_save(self) -> List[Dict]:
|
||||
"""Get all steps as dictionaries"""
|
||||
return [step.to_dict() for step in self.all_steps]
|
||||
|
||||
|
||||
class StreamRenderer:
|
||||
"""Renders stream state to SSE events"""
|
||||
|
||||
@staticmethod
|
||||
def render_tool_calls(state: StreamState) -> List[str]:
|
||||
"""Render tool calls as SSE events"""
|
||||
events = []
|
||||
for tc in self.tool_calls_list:
|
||||
step_id = f"step-{self.step_index}"
|
||||
self.step_index += 1
|
||||
|
||||
step = {
|
||||
"id": step_id,
|
||||
"index": self.step_index - 1,
|
||||
"type": "tool_call",
|
||||
"id_ref": tc.get("id", ""),
|
||||
"name": tc["function"]["name"],
|
||||
"arguments": tc["function"]["arguments"]
|
||||
}
|
||||
self.all_steps.append(step)
|
||||
self.all_tool_calls.append(tc)
|
||||
events.append(_sse_event("process_step", {"step": step}))
|
||||
|
||||
for tc in state.tool_calls_list:
|
||||
step_id = f"step-{state.step_index}"
|
||||
state.step_index += 1
|
||||
step = Step(
|
||||
id=step_id,
|
||||
index=state.step_index - 1,
|
||||
type=StepType.TOOL_CALL,
|
||||
name=tc["function"]["name"],
|
||||
arguments=tc["function"]["arguments"],
|
||||
id_ref=tc.get("id", "")
|
||||
)
|
||||
state.all_steps.append(step)
|
||||
events.append(sse_event("process_step", {"step": step.to_dict()}))
|
||||
return events
|
||||
|
||||
def emit_tool_result(self, result: Dict, ref_step_id: str) -> tuple:
|
||||
"""Emit tool result step, return (step, event)."""
|
||||
step_id = f"step-{self.step_index}"
|
||||
self.step_index += 1
|
||||
@staticmethod
|
||||
def render_tool_result(state: StreamState, result: Dict, ref_step_id: str) -> tuple:
|
||||
"""Render a tool result as SSE event"""
|
||||
import json
|
||||
|
||||
step_id = f"step-{state.step_index}"
|
||||
state.step_index += 1
|
||||
content = result.get("content", "")
|
||||
success = True
|
||||
|
||||
try:
|
||||
parsed = json.loads(content)
|
||||
if isinstance(parsed, dict):
|
||||
|
|
@ -116,32 +231,49 @@ class StreamContext:
|
|||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
step = {
|
||||
"id": step_id,
|
||||
"index": self.step_index - 1,
|
||||
"type": "tool_result",
|
||||
"id_ref": ref_step_id,
|
||||
"name": result.get("name", ""),
|
||||
"content": content,
|
||||
"success": success
|
||||
}
|
||||
self.all_steps.append(step)
|
||||
self.all_tool_results.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": result.get("tool_call_id", ""),
|
||||
"content": content
|
||||
})
|
||||
step = Step(
|
||||
id=step_id,
|
||||
index=state.step_index - 1,
|
||||
type=StepType.TOOL_RESULT,
|
||||
name=result.get("name", ""),
|
||||
content=content,
|
||||
id_ref=ref_step_id,
|
||||
success=success
|
||||
)
|
||||
state.all_steps.append(step)
|
||||
state.add_tool_result(result)
|
||||
|
||||
return step, _sse_event("process_step", {"step": step})
|
||||
return step, sse_event("process_step", {"step": step.to_dict()})
|
||||
|
||||
def set_completion(self, msg_id: str, token_count: int, usage: dict):
|
||||
"""Set completion info for saving."""
|
||||
self._last_message_id = msg_id
|
||||
self._last_token_count = token_count
|
||||
self._last_usage = usage
|
||||
@staticmethod
|
||||
def render_thinking(state: StreamState) -> str:
|
||||
"""Render thinking content as SSE event"""
|
||||
step = Step(
|
||||
id=state.current_step_id,
|
||||
index=state.current_step_idx,
|
||||
type=StepType.THINKING,
|
||||
content=state.full_thinking
|
||||
)
|
||||
return sse_event("process_step", {"step": step.to_dict()})
|
||||
|
||||
def reset_completion(self):
|
||||
"""Reset completion info."""
|
||||
self._last_message_id = None
|
||||
self._last_token_count = 0
|
||||
self._last_usage = None
|
||||
@staticmethod
|
||||
def render_text(state: StreamState) -> str:
|
||||
"""Render text content as SSE event"""
|
||||
step = Step(
|
||||
id=state.current_step_id,
|
||||
index=state.current_step_idx,
|
||||
type=StepType.TEXT,
|
||||
content=state.full_content
|
||||
)
|
||||
return sse_event("process_step", {"step": step.to_dict()})
|
||||
|
||||
@staticmethod
|
||||
def render_error(error_msg: str) -> str:
|
||||
"""Render error event"""
|
||||
return sse_event("error", {"content": error_msg})
|
||||
|
||||
|
||||
# Convenience function for backward compatibility
|
||||
def render_error(error_msg: str) -> str:
|
||||
"""Render error event"""
|
||||
return sse_event("error", {"content": error_msg})
|
||||
|
|
|
|||
|
|
@ -1,61 +1,121 @@
|
|||
"""Tool executor"""
|
||||
"""Tool executor with caching and parallel execution support
|
||||
|
||||
This module follows the Single Responsibility Principle:
|
||||
- ToolExecutor: Tool execution logic
|
||||
- CallHistory: Call history management
|
||||
- CacheManager: Caching logic
|
||||
"""
|
||||
import json
|
||||
import time
|
||||
from typing import List, Dict, Any, Optional
|
||||
from threading import Lock
|
||||
|
||||
from luxx.tools.core import registry, ToolResult, ToolContext
|
||||
from luxx.tools.core import registry, ToolContext
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
"""Tool executor with caching and parallel execution support"""
|
||||
class CacheManager:
|
||||
"""Manages tool result caching"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
enable_cache: bool = True,
|
||||
cache_ttl: int = 300, # 5 minutes
|
||||
max_workers: int = 4
|
||||
):
|
||||
def __init__(self, enable_cache: bool = True, cache_ttl: int = 300):
|
||||
self.enable_cache = enable_cache
|
||||
self.cache_ttl = cache_ttl
|
||||
self.max_workers = max_workers
|
||||
self._cache: Dict[str, tuple] = {} # key: (result, timestamp)
|
||||
self._call_history: List[Dict[str, Any]] = []
|
||||
self._lock = Lock()
|
||||
|
||||
def _make_cache_key(self, name: str, args: dict) -> str:
|
||||
def make_key(self, name: str, args: dict, workspace: str = None) -> str:
|
||||
"""Generate cache key"""
|
||||
args_str = json.dumps(args, sort_keys=True, ensure_ascii=False)
|
||||
return f"{name}:{args_str}"
|
||||
key = f"{name}:{args_str}"
|
||||
if workspace:
|
||||
key = f"{key}:{workspace}"
|
||||
return key
|
||||
|
||||
def _is_cache_valid(self, cache_key: str) -> bool:
|
||||
def is_valid(self, cache_key: str) -> bool:
|
||||
"""Check if cache is valid"""
|
||||
if cache_key not in self._cache:
|
||||
return False
|
||||
_, timestamp = self._cache[cache_key]
|
||||
return time.time() - timestamp < self.cache_ttl
|
||||
|
||||
def _get_cached(self, cache_key: str) -> Optional[Dict]:
|
||||
def get(self, cache_key: str) -> Optional[Dict]:
|
||||
"""Get cached result"""
|
||||
if self.enable_cache and self._is_cache_valid(cache_key):
|
||||
if not self.enable_cache:
|
||||
return None
|
||||
if self.is_valid(cache_key):
|
||||
return self._cache[cache_key][0]
|
||||
return None
|
||||
|
||||
def _set_cached(self, cache_key: str, result: Dict) -> None:
|
||||
def set(self, cache_key: str, result: Dict) -> None:
|
||||
"""Set cache"""
|
||||
if self.enable_cache:
|
||||
if not self.enable_cache:
|
||||
return
|
||||
with self._lock:
|
||||
self._cache[cache_key] = (result, time.time())
|
||||
|
||||
def _record_call(self, name: str, args: dict, result: Dict) -> None:
|
||||
"""Record call history"""
|
||||
self._call_history.append({
|
||||
def clear(self) -> None:
|
||||
"""Clear all cache"""
|
||||
with self._lock:
|
||||
self._cache.clear()
|
||||
|
||||
def size(self) -> int:
|
||||
"""Get cache size"""
|
||||
return len(self._cache)
|
||||
|
||||
|
||||
class CallHistory:
|
||||
"""Manages tool call history"""
|
||||
|
||||
MAX_HISTORY_SIZE = 1000
|
||||
|
||||
def __init__(self):
|
||||
self._history: List[Dict[str, Any]] = []
|
||||
self._lock = Lock()
|
||||
|
||||
def record(self, name: str, args: dict, result: Dict) -> None:
|
||||
"""Record a tool call"""
|
||||
entry = {
|
||||
"name": name,
|
||||
"args": args,
|
||||
"result": result,
|
||||
"timestamp": time.time()
|
||||
})
|
||||
|
||||
}
|
||||
with self._lock:
|
||||
self._history.append(entry)
|
||||
# Limit history size
|
||||
if len(self._call_history) > 1000:
|
||||
self._call_history = self._call_history[-1000:]
|
||||
if len(self._history) > self.MAX_HISTORY_SIZE:
|
||||
self._history = self._history[-self.MAX_HISTORY_SIZE:]
|
||||
|
||||
def get(self, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""Get recent call history"""
|
||||
with self._lock:
|
||||
return self._history[-limit:].copy()
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all history"""
|
||||
with self._lock:
|
||||
self._history.clear()
|
||||
|
||||
def size(self) -> int:
|
||||
"""Get history size"""
|
||||
return len(self._history)
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
"""Tool executor with caching and parallel execution support
|
||||
|
||||
This class delegates caching and history to specialized classes,
|
||||
following the Single Responsibility Principle.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
enable_cache: bool = True,
|
||||
cache_ttl: int = 300,
|
||||
max_workers: int = 4
|
||||
):
|
||||
self.cache = CacheManager(enable_cache=enable_cache, cache_ttl=cache_ttl)
|
||||
self.history = CallHistory()
|
||||
self.max_workers = max_workers
|
||||
|
||||
def process_tool_calls(
|
||||
self,
|
||||
|
|
@ -63,16 +123,8 @@ class ToolExecutor:
|
|||
context: Dict[str, Any]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Process tool calls sequentially"""
|
||||
# Build ToolContext from context dict (includes user_permission_level)
|
||||
tool_ctx = ToolContext(
|
||||
workspace=context.get("workspace"),
|
||||
user_id=context.get("user_id"),
|
||||
username=context.get("username"),
|
||||
extra={
|
||||
"user_permission_level": context.get("user_permission_level", 1),
|
||||
**(context.get("extra", {}))
|
||||
}
|
||||
)
|
||||
# Build ToolContext from context dict
|
||||
tool_ctx = self._build_tool_context(context)
|
||||
|
||||
results = []
|
||||
|
||||
|
|
@ -81,26 +133,21 @@ class ToolExecutor:
|
|||
name = call.get("function", {}).get("name", "")
|
||||
|
||||
# Parse JSON arguments
|
||||
try:
|
||||
args = json.loads(call.get("function", {}).get("arguments", "{}"))
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
args = self._parse_arguments(call)
|
||||
|
||||
# Check cache (include context in cache key for file operations)
|
||||
cache_key = self._make_cache_key(name, args)
|
||||
if tool_ctx.workspace:
|
||||
cache_key = f"{cache_key}:{tool_ctx.workspace}"
|
||||
cached = self._get_cached(cache_key)
|
||||
# Check cache
|
||||
cache_key = self.cache.make_key(name, args, tool_ctx.workspace)
|
||||
cached = self.cache.get(cache_key)
|
||||
|
||||
if cached is not None:
|
||||
result = cached
|
||||
else:
|
||||
# Execute tool with context
|
||||
result = registry.execute(name, args, context=tool_ctx)
|
||||
self._set_cached(cache_key, result)
|
||||
self.cache.set(cache_key, result)
|
||||
|
||||
# Record call
|
||||
self._record_call(name, args, result)
|
||||
self.history.record(name, args, result)
|
||||
|
||||
# Create result message
|
||||
results.append(self._create_tool_result(call_id, name, result))
|
||||
|
|
@ -116,16 +163,7 @@ class ToolExecutor:
|
|||
if len(tool_calls) <= 1:
|
||||
return self.process_tool_calls(tool_calls, context)
|
||||
|
||||
# Build ToolContext from context dict (includes user_permission_level)
|
||||
tool_ctx = ToolContext(
|
||||
workspace=context.get("workspace"),
|
||||
user_id=context.get("user_id"),
|
||||
username=context.get("username"),
|
||||
extra={
|
||||
"user_permission_level": context.get("user_permission_level", 1),
|
||||
**(context.get("extra", {}))
|
||||
}
|
||||
)
|
||||
tool_ctx = self._build_tool_context(context)
|
||||
|
||||
try:
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
|
@ -136,45 +174,61 @@ class ToolExecutor:
|
|||
for call in tool_calls:
|
||||
call_id = call.get("id", "")
|
||||
name = call.get("function", {}).get("name", "")
|
||||
|
||||
# Parse all arguments
|
||||
try:
|
||||
args = json.loads(call.get("function", {}).get("arguments", "{}"))
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
args = self._parse_arguments(call)
|
||||
|
||||
# Check cache
|
||||
cache_key = self._make_cache_key(name, args)
|
||||
if tool_ctx.workspace:
|
||||
cache_key = f"{cache_key}:{tool_ctx.workspace}"
|
||||
cached = self._get_cached(cache_key)
|
||||
cache_key = self.cache.make_key(name, args, tool_ctx.workspace)
|
||||
cached = self.cache.get(cache_key)
|
||||
|
||||
if cached is not None:
|
||||
futures[call_id] = (name, args, cached)
|
||||
else:
|
||||
# Submit task with context
|
||||
future = executor.submit(registry.execute, name, args, context=tool_ctx)
|
||||
# Submit task
|
||||
future = executor.submit(
|
||||
registry.execute, name, args, context=tool_ctx
|
||||
)
|
||||
futures[future] = (call_id, name, args, cache_key)
|
||||
|
||||
results = []
|
||||
|
||||
for future in as_completed(futures.keys()):
|
||||
if future in futures:
|
||||
item = futures[future]
|
||||
if len(item) == 3:
|
||||
call_id, name, args = item
|
||||
cache_key = self._make_cache_key(name, args)
|
||||
cache_key = self.cache.make_key(name, args, tool_ctx.workspace)
|
||||
result = item[2]
|
||||
else:
|
||||
call_id, name, args, cache_key = item
|
||||
result = future.result()
|
||||
self._set_cached(cache_key, result)
|
||||
self._record_call(name, args, result)
|
||||
self.cache.set(cache_key, result)
|
||||
|
||||
self.history.record(name, args, result)
|
||||
results.append(self._create_tool_result(call_id, name, result))
|
||||
|
||||
return results
|
||||
|
||||
except ImportError:
|
||||
return self.process_tool_calls(tool_calls, context)
|
||||
|
||||
def _build_tool_context(self, context: Dict[str, Any]) -> ToolContext:
|
||||
"""Build ToolContext from context dict"""
|
||||
return ToolContext(
|
||||
workspace=context.get("workspace"),
|
||||
user_id=context.get("user_id"),
|
||||
username=context.get("username"),
|
||||
extra={
|
||||
"user_permission_level": context.get("user_permission_level", 1),
|
||||
**(context.get("extra", {}))
|
||||
}
|
||||
)
|
||||
|
||||
def _parse_arguments(self, call: Dict[str, Any]) -> Dict:
|
||||
"""Parse JSON arguments from tool call"""
|
||||
try:
|
||||
return json.loads(call.get("function", {}).get("arguments", "{}"))
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
|
||||
def _create_tool_result(self, call_id: str, name: str, result: Dict) -> Dict[str, Any]:
|
||||
"""Create tool result message"""
|
||||
return {
|
||||
|
|
@ -184,19 +238,10 @@ class ToolExecutor:
|
|||
"content": json.dumps(result, ensure_ascii=False)
|
||||
}
|
||||
|
||||
def _create_error_result(self, call_id: str, name: str, error: str) -> Dict[str, Any]:
|
||||
"""Create error result message"""
|
||||
return {
|
||||
"tool_call_id": call_id,
|
||||
"role": "tool",
|
||||
"name": name,
|
||||
"content": json.dumps({"success": False, "error": error}, ensure_ascii=False)
|
||||
}
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear all cache"""
|
||||
self._cache.clear()
|
||||
self.cache.clear()
|
||||
|
||||
def get_history(self, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""Get call history"""
|
||||
return self._call_history[-limit:]
|
||||
return self.history.get(limit)
|
||||
|
|
|
|||
|
|
@ -31,3 +31,11 @@ dev = ["pytest>=8.0.0", "pytest-asyncio>=0.23.0", "pytest-cov>=4.1.0", "black>=2
|
|||
|
||||
[tool.setuptools]
|
||||
packages = ["luxx"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
python_files = ["test_*.py"]
|
||||
python_classes = ["Test*"]
|
||||
python_functions = ["test_*"]
|
||||
asyncio_mode = "auto"
|
||||
addopts = "-v --tb=short"
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
"""Test suite for Luxx project"""
|
||||
|
|
@ -0,0 +1,32 @@
|
|||
"""Pytest configuration and fixtures"""
|
||||
import os
|
||||
import sys
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
# Set test environment variables
|
||||
os.environ.setdefault("APP_SECRET_KEY", "test-secret-key-for-testing")
|
||||
os.environ.setdefault("DEEPSEEK_API_KEY", "test-api-key")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_workspace(tmp_path):
|
||||
"""Create a temporary workspace for testing"""
|
||||
workspace = tmp_path / "test_workspace"
|
||||
workspace.mkdir()
|
||||
return workspace
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_user_context():
|
||||
"""Sample user context for tool testing"""
|
||||
return {
|
||||
"user_id": 1,
|
||||
"username": "test_user",
|
||||
"workspace": "/tmp/test_workspace",
|
||||
"user_permission_level": 3
|
||||
}
|
||||
|
|
@ -0,0 +1,333 @@
|
|||
"""Tests for LLM response parsing - MiniMax/OpenAI format"""
|
||||
import json
|
||||
import pytest
|
||||
from luxx.services.llm_response import ParsedDelta
|
||||
|
||||
|
||||
class TestMiniMaxParsing:
|
||||
"""Tests for MiniMax/OpenAI streaming format parsing"""
|
||||
|
||||
def test_parse_text_chunk(self):
|
||||
"""Parse text content chunk"""
|
||||
chunk_str = 'data: {"id":"msg_001","choices":[{"index":0,"delta":{"content":"Hello","role":"assistant"}}]}'
|
||||
chunk = json.loads(chunk_str[6:]) # Remove "data: "
|
||||
|
||||
delta = chunk["choices"][0]["delta"]
|
||||
assert delta["content"] == "Hello"
|
||||
|
||||
def test_parse_text_accumulation(self):
|
||||
"""Multiple chunks should accumulate to full text"""
|
||||
chunks = [
|
||||
'{"id":"msg_001","choices":[{"index":0,"delta":{"content":"你好","role":"assistant"}}]}',
|
||||
'{"id":"msg_001","choices":[{"index":0,"delta":{"content":"!","role":"assistant"}}]}',
|
||||
'{"id":"msg_001","choices":[{"index":0,"delta":{"content":"有什么","role":"assistant"}}]}',
|
||||
'{"id":"msg_001","choices":[{"index":0,"delta":{"content":"可以帮助你的吗?","role":"assistant"}}]}',
|
||||
]
|
||||
|
||||
full_text = ""
|
||||
for c in chunks:
|
||||
chunk = json.loads(c)
|
||||
content = chunk["choices"][0]["delta"].get("content", "")
|
||||
full_text += content
|
||||
|
||||
assert full_text == "你好!有什么可以帮助你的吗?"
|
||||
|
||||
def test_parse_finish_chunk(self):
|
||||
"""Parse finish_reason chunk"""
|
||||
chunk_str = 'data: {"id":"msg_001","choices":[{"finish_reason":"stop","index":0,"delta":{"role":"assistant"}}]}'
|
||||
chunk = json.loads(chunk_str[6:])
|
||||
|
||||
finish_reason = chunk["choices"][0].get("finish_reason")
|
||||
assert finish_reason == "stop"
|
||||
|
||||
def test_parse_usage_chunk(self):
|
||||
"""Parse usage information"""
|
||||
chunk_str = '{"id":"msg_001","usage":{"prompt_tokens":100,"completion_tokens":50,"total_tokens":150}}'
|
||||
chunk = json.loads(chunk_str)
|
||||
|
||||
usage = chunk.get("usage", {})
|
||||
assert usage["prompt_tokens"] == 100
|
||||
assert usage["completion_tokens"] == 50
|
||||
assert usage["total_tokens"] == 150
|
||||
|
||||
def test_parse_empty_content_chunk(self):
|
||||
"""Parse chunk with empty content (just role)"""
|
||||
chunk_str = 'data: {"id":"msg_001","choices":[{"index":0,"delta":{"content":"","role":"assistant"}}]}'
|
||||
chunk = json.loads(chunk_str[6:])
|
||||
|
||||
delta = chunk["choices"][0]["delta"]
|
||||
content = delta.get("content", "")
|
||||
assert content == ""
|
||||
|
||||
def test_parse_done_marker(self):
|
||||
"""Parse [DONE] marker"""
|
||||
chunk_str = "data: [DONE]"
|
||||
assert chunk_str.strip().startswith("data: [DONE]")
|
||||
|
||||
|
||||
class TestOpenAIAdapter:
|
||||
"""Tests for OpenAI adapter parsing logic"""
|
||||
|
||||
@pytest.fixture
|
||||
def adapter(self):
|
||||
from luxx.services.llm_adapters.openai_adapter import OpenAIAdapter
|
||||
return OpenAIAdapter()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_stream_text(self, adapter):
|
||||
"""Should parse text content"""
|
||||
chunk = 'data: {"id":"msg_001","choices":[{"index":0,"delta":{"content":"Hello"}}]}'
|
||||
deltas = [d async for d in adapter.parse_stream_chunk(chunk)]
|
||||
|
||||
assert len(deltas) == 1
|
||||
assert deltas[0].text == "Hello"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_stream_finish(self, adapter):
|
||||
"""Should detect completion"""
|
||||
chunk = 'data: {"id":"msg_001","choices":[{"finish_reason":"stop","index":0,"delta":{}}]}'
|
||||
deltas = [d async for d in adapter.parse_stream_chunk(chunk)]
|
||||
|
||||
assert len(deltas) == 1
|
||||
assert deltas[0].is_complete is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_stream_empty_content(self, adapter):
|
||||
"""Should skip empty content chunks"""
|
||||
chunk = 'data: {"id":"msg_001","choices":[{"index":0,"delta":{"role":"assistant"}}]}'
|
||||
deltas = [d async for d in adapter.parse_stream_chunk(chunk)]
|
||||
|
||||
# Empty content without finish_reason should be skipped
|
||||
assert len(deltas) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_stream_done_marker(self, adapter):
|
||||
"""Should handle [DONE] marker"""
|
||||
chunk = "data: [DONE]"
|
||||
deltas = [d async for d in adapter.parse_stream_chunk(chunk)]
|
||||
|
||||
assert len(deltas) == 1
|
||||
assert deltas[0].is_complete is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_stream_invalid_json(self, adapter):
|
||||
"""Should handle invalid JSON gracefully"""
|
||||
chunk = "not valid json"
|
||||
deltas = [d async for d in adapter.parse_stream_chunk(chunk)]
|
||||
|
||||
assert len(deltas) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_stream_empty_chunk(self, adapter):
|
||||
"""Should handle empty chunk"""
|
||||
deltas = [d async for d in adapter.parse_stream_chunk("")]
|
||||
assert len(deltas) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_stream_whitespace_chunk(self, adapter):
|
||||
"""Should handle whitespace-only chunk"""
|
||||
deltas = [d async for d in adapter.parse_stream_chunk(" \n")]
|
||||
assert len(deltas) == 0
|
||||
|
||||
|
||||
class TestBuildRequest:
|
||||
"""Tests for request building"""
|
||||
|
||||
@pytest.fixture
|
||||
def adapter(self):
|
||||
from luxx.services.llm_adapters.openai_adapter import OpenAIAdapter
|
||||
return OpenAIAdapter()
|
||||
|
||||
def test_build_request_basic(self, adapter):
|
||||
"""Should build basic request"""
|
||||
body, headers = adapter.build_request(
|
||||
model="MiniMax-M2.5",
|
||||
messages=[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
|
||||
assert body["model"] == "MiniMax-M2.5"
|
||||
assert body["messages"] == [{"role": "user", "content": "Hello"}]
|
||||
assert body["stream"] is True
|
||||
|
||||
def test_build_request_with_tools(self, adapter):
|
||||
"""Should include tools in request"""
|
||||
tools = [{"type": "function", "function": {"name": "test", "parameters": {}}}]
|
||||
body, _ = adapter.build_request(
|
||||
model="MiniMax-M2.5",
|
||||
messages=[],
|
||||
tools=tools
|
||||
)
|
||||
|
||||
assert "tools" in body
|
||||
assert body["tool_choice"] == "auto"
|
||||
|
||||
def test_build_request_with_temperature(self, adapter):
|
||||
"""Should include temperature"""
|
||||
body, _ = adapter.build_request(
|
||||
model="MiniMax-M2.5",
|
||||
messages=[],
|
||||
temperature=0.7
|
||||
)
|
||||
|
||||
assert body["temperature"] == 0.7
|
||||
|
||||
def test_build_request_with_max_tokens(self, adapter):
|
||||
"""Should include max_tokens"""
|
||||
body, _ = adapter.build_request(
|
||||
model="MiniMax-M2.5",
|
||||
messages=[],
|
||||
max_tokens=1000
|
||||
)
|
||||
|
||||
assert body["max_tokens"] == 1000
|
||||
|
||||
|
||||
class TestParsedDelta:
|
||||
"""Tests for ParsedDelta dataclass"""
|
||||
|
||||
def test_default_values(self):
|
||||
"""Should have correct defaults"""
|
||||
delta = ParsedDelta()
|
||||
assert delta.text == ""
|
||||
assert delta.thinking == ""
|
||||
assert delta.tool_call is None
|
||||
assert delta.usage == {}
|
||||
assert delta.is_complete is False
|
||||
|
||||
def test_with_text(self):
|
||||
"""Should accept text content"""
|
||||
delta = ParsedDelta(text="Hello world")
|
||||
assert delta.text == "Hello world"
|
||||
|
||||
def test_with_usage(self):
|
||||
"""Should accept usage dict"""
|
||||
delta = ParsedDelta(usage={"prompt_tokens": 10, "completion_tokens": 5})
|
||||
assert delta.usage["prompt_tokens"] == 10
|
||||
|
||||
def test_with_complete_flag(self):
|
||||
"""Should accept is_complete flag"""
|
||||
delta = ParsedDelta(is_complete=True)
|
||||
assert delta.is_complete is True
|
||||
|
||||
|
||||
class TestEndToEndStreaming:
|
||||
"""End-to-end streaming simulation tests"""
|
||||
|
||||
@pytest.fixture
|
||||
def adapter(self):
|
||||
from luxx.services.llm_adapters.openai_adapter import OpenAIAdapter
|
||||
return OpenAIAdapter()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_text_stream(self, adapter):
|
||||
"""Simulate full text response stream"""
|
||||
chunks = [
|
||||
'data: {"id":"msg_001","choices":[{"index":0,"delta":{"content":"用户","role":"assistant"}}]}',
|
||||
'data: {"id":"msg_001","choices":[{"index":0,"delta":{"content":"用中文","role":"assistant"}}]}',
|
||||
'data: {"id":"msg_001","choices":[{"index":0,"delta":{"content":"说","role":"assistant"}}]}',
|
||||
'data: {"id":"msg_001","choices":[{"index":0,"delta":{"content":"\"你好\"","role":"assistant"}}]}',
|
||||
'data: {"id":"msg_001","choices":[{"index":0,"delta":{"content":",这是","role":"assistant"}}]}',
|
||||
'data: {"id":"msg_001","choices":[{"index":0,"delta":{"content":"一个简单的","role":"assistant"}}]}',
|
||||
'data: {"id":"msg_001","choices":[{"index":0,"delta":{"content":"问候。","role":"assistant"}}]}',
|
||||
'data: {"id":"msg_001","choices":[{"index":0,"delta":{"content":"\n","role":"assistant"}}]}',
|
||||
'data: {"id":"msg_001","choices":[{"index":0,"delta":{"content":"我应该","role":"assistant"}}]}',
|
||||
'data: {"id":"msg_001","choices":[{"index":0,"delta":{"content":"用中文","role":"assistant"}}]}',
|
||||
'data: {"id":"msg_001","choices":[{"index":0,"delta":{"content":"友好地","role":"assistant"}}]}',
|
||||
'data: {"id":"msg_001","choices":[{"index":0,"delta":{"content":"回应。","role":"assistant"}}]}',
|
||||
'data: {"id":"msg_001","choices":[{"index":0,"delta":{"content":"\n\n","role":"assistant"}}]}',
|
||||
'data: {"id":"msg_001","choices":[{"index":0,"delta":{"content":"","role":"assistant"}}]}',
|
||||
'data: {"id":"msg_001","choices":[{"index":0,"delta":{"content":"\n\n你好!","role":"assistant"}}]}',
|
||||
'data: {"id":"msg_001","choices":[{"index":0,"delta":{"content":"有什么","role":"assistant"}}]}',
|
||||
'data: {"id":"msg_001","choices":[{"index":0,"delta":{"content":"我可以","role":"assistant"}}]}',
|
||||
'data: {"id":"msg_001","choices":[{"index":0,"delta":{"content":"帮助你的吗?","role":"assistant"}}]}',
|
||||
'data: {"id":"msg_001","choices":[{"finish_reason":"stop","index":0,"delta":{"role":"assistant"}}]}',
|
||||
]
|
||||
|
||||
full_text = ""
|
||||
is_complete = False
|
||||
|
||||
for chunk in chunks:
|
||||
deltas = [d async for d in adapter.parse_stream_chunk(chunk)]
|
||||
for delta in deltas:
|
||||
if delta.is_complete:
|
||||
is_complete = True
|
||||
if delta.text:
|
||||
full_text += delta.text
|
||||
|
||||
expected = "用户用中文说\"你好\",这是一个简单的问候。我应该用中文友好地回应。\n\n你好!有什么我可以帮助你的吗?"
|
||||
assert full_text == expected
|
||||
assert is_complete is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_stream_between_finish(self, adapter):
|
||||
"""Handle empty content chunks before finish"""
|
||||
chunks = [
|
||||
'data: {"id":"msg_001","choices":[{"index":0,"delta":{"content":"Hello","role":"assistant"}}]}',
|
||||
'data: {"id":"msg_001","choices":[{"index":0,"delta":{"content":"","role":"assistant"}}]}',
|
||||
'data: {"id":"msg_001","choices":[{"index":0,"delta":{"content":"","role":"assistant"}}]}',
|
||||
'data: {"id":"msg_001","choices":[{"finish_reason":"stop","index":0,"delta":{"role":"assistant"}}]}',
|
||||
]
|
||||
|
||||
full_text = ""
|
||||
for chunk in chunks:
|
||||
deltas = [d async for d in adapter.parse_stream_chunk(chunk)]
|
||||
for delta in deltas:
|
||||
if delta.text:
|
||||
full_text += delta.text
|
||||
|
||||
assert full_text == "Hello"
|
||||
|
||||
|
||||
class TestToolCallParsing:
|
||||
"""Tests for tool call parsing"""
|
||||
|
||||
@pytest.fixture
|
||||
def adapter(self):
|
||||
from luxx.services.llm_adapters.openai_adapter import OpenAIAdapter
|
||||
return OpenAIAdapter()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_chunk(self, adapter):
|
||||
"""Parse tool call chunk"""
|
||||
chunk = json.dumps({
|
||||
"id": "chatcmpl_001",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"tool_calls": [{
|
||||
"index": 0,
|
||||
"id": "call_abc123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "web_search",
|
||||
"arguments": '{"query":'
|
||||
}
|
||||
}]
|
||||
}
|
||||
}]
|
||||
})
|
||||
|
||||
# The current adapter doesn't yield tool_call deltas
|
||||
# This test documents current behavior
|
||||
deltas = [d async for d in adapter.parse_stream_chunk(chunk)]
|
||||
# Content is empty, no delta yielded
|
||||
assert len(deltas) == 0
|
||||
|
||||
def test_tool_call_accumulation(self):
|
||||
"""Simulate tool call argument accumulation"""
|
||||
chunks = [
|
||||
{"function": {"name": "python_eval", "arguments": "{"}},
|
||||
{"function": {"name": "", "arguments": '"expr": '}},
|
||||
{"function": {"name": "", "arguments": '"1 + 1"'}},
|
||||
{"function": {"name": "", "arguments": "}"}},
|
||||
]
|
||||
|
||||
accumulated = {"name": "", "arguments": ""}
|
||||
for c in chunks:
|
||||
if c["function"]["name"]:
|
||||
accumulated["name"] += c["function"]["name"]
|
||||
if c["function"]["arguments"]:
|
||||
accumulated["arguments"] += c["function"]["arguments"]
|
||||
|
||||
assert accumulated["name"] == "python_eval"
|
||||
assert accumulated["arguments"] == '{"expr": "1 + 1"}'
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
"""Tests for config module"""
|
||||
|
||||
class TestConfig:
|
||||
"""Tests for Config class"""
|
||||
|
||||
def test_config_singleton(self):
|
||||
"""Should return same instance"""
|
||||
from luxx.config import config, Config
|
||||
config1 = Config()
|
||||
config2 = Config()
|
||||
assert config1 is config2
|
||||
assert config is config1
|
||||
|
||||
def test_get_with_default(self):
|
||||
"""Should return default value for missing key"""
|
||||
from luxx.config import config
|
||||
result = config.get("nonexistent.key", "default_value")
|
||||
assert result == "default_value"
|
||||
|
||||
def test_get_nested_key(self):
|
||||
"""Should support dot-separated keys"""
|
||||
from luxx.config import config
|
||||
# These should return configured or default values
|
||||
secret = config.get("app.secret_key")
|
||||
assert secret is not None
|
||||
|
||||
def test_properties_have_defaults(self):
|
||||
"""All properties should have sensible defaults"""
|
||||
from luxx.config import config
|
||||
assert isinstance(config.debug, bool)
|
||||
assert isinstance(config.app_host, str)
|
||||
assert isinstance(config.app_port, int)
|
||||
assert isinstance(config.database_url, str)
|
||||
assert config.app_port == 8000
|
||||
|
||||
def test_tools_config_properties(self):
|
||||
"""Tools configuration properties should work"""
|
||||
from luxx.config import config
|
||||
assert config.tools_enable_cache is not None
|
||||
assert config.tools_cache_ttl > 0
|
||||
assert config.tools_max_workers > 0
|
||||
assert config.tools_max_iterations > 0
|
||||
|
||||
def test_llm_config_properties(self):
|
||||
"""LLM configuration properties should work"""
|
||||
from luxx.config import config
|
||||
assert config.llm_provider is not None
|
||||
assert config.llm_api_url is not None
|
||||
|
|
@ -0,0 +1,110 @@
|
|||
"""Tests for utils/helpers module"""
|
||||
from luxx.utils.helpers import (
|
||||
generate_id,
|
||||
hash_password,
|
||||
verify_password,
|
||||
create_access_token,
|
||||
decode_access_token,
|
||||
success_response,
|
||||
error_response
|
||||
)
|
||||
|
||||
|
||||
class TestGenerateId:
|
||||
"""Tests for generate_id function"""
|
||||
|
||||
def test_generate_id_returns_string(self):
|
||||
"""Should return a string"""
|
||||
result = generate_id()
|
||||
assert isinstance(result, str)
|
||||
|
||||
def test_generate_id_with_prefix(self):
|
||||
"""Should return id with prefix"""
|
||||
result = generate_id("task")
|
||||
assert result.startswith("task_")
|
||||
|
||||
def test_generate_id_unique(self):
|
||||
"""Should generate unique ids"""
|
||||
ids = [generate_id() for _ in range(100)]
|
||||
assert len(set(ids)) == 100
|
||||
|
||||
|
||||
class TestPasswordHashing:
|
||||
"""Tests for password hashing functions"""
|
||||
|
||||
def test_hash_password_returns_string(self):
|
||||
"""Should return a hashed string"""
|
||||
password = "test_password_123"
|
||||
hashed = hash_password(password)
|
||||
assert isinstance(hashed, str)
|
||||
assert hashed != password
|
||||
|
||||
def test_verify_password_correct(self):
|
||||
"""Should return True for correct password"""
|
||||
password = "test_password_123"
|
||||
hashed = hash_password(password)
|
||||
assert verify_password(password, hashed) is True
|
||||
|
||||
def test_verify_password_incorrect(self):
|
||||
"""Should return False for incorrect password"""
|
||||
password = "test_password_123"
|
||||
wrong_password = "wrong_password"
|
||||
hashed = hash_password(password)
|
||||
assert verify_password(wrong_password, hashed) is False
|
||||
|
||||
|
||||
class TestJWTToken:
|
||||
"""Tests for JWT token functions"""
|
||||
|
||||
def test_create_access_token_returns_string(self):
|
||||
"""Should return a JWT token string"""
|
||||
token = create_access_token({"user_id": 1})
|
||||
assert isinstance(token, str)
|
||||
assert len(token) > 0
|
||||
|
||||
def test_decode_access_token_valid(self):
|
||||
"""Should decode valid token"""
|
||||
payload = {"user_id": 1, "username": "test"}
|
||||
token = create_access_token(payload)
|
||||
decoded = decode_access_token(token)
|
||||
assert decoded is not None
|
||||
assert decoded["user_id"] == 1
|
||||
assert decoded["username"] == "test"
|
||||
|
||||
def test_decode_access_token_invalid(self):
|
||||
"""Should return None for invalid token"""
|
||||
result = decode_access_token("invalid.token.here")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestResponseWrappers:
|
||||
"""Tests for response wrapper functions"""
|
||||
|
||||
def test_success_response_format(self):
|
||||
"""Should return correct success format"""
|
||||
result = success_response({"key": "value"}, "Success message")
|
||||
assert result["success"] is True
|
||||
assert result["message"] == "Success message"
|
||||
assert result["data"] == {"key": "value"}
|
||||
|
||||
def test_success_response_default(self):
|
||||
"""Should use default values"""
|
||||
result = success_response()
|
||||
assert result["success"] is True
|
||||
assert result["message"] == "Success"
|
||||
assert result["data"] is None
|
||||
|
||||
def test_error_response_format(self):
|
||||
"""Should return correct error format"""
|
||||
result = error_response("Error occurred", code=404)
|
||||
assert result["success"] is False
|
||||
assert result["message"] == "Error occurred"
|
||||
assert result["code"] == 404
|
||||
assert "errors" not in result
|
||||
|
||||
def test_error_response_with_errors(self):
|
||||
"""Should include errors field"""
|
||||
errors = {"field": ["required"]}
|
||||
result = error_response("Validation failed", code=400, errors=errors)
|
||||
assert result["success"] is False
|
||||
assert result["errors"] == errors
|
||||
|
|
@ -0,0 +1,223 @@
|
|||
"""Tests for task module"""
|
||||
import pytest
|
||||
from luxx.services.task import (
|
||||
Task,
|
||||
Step,
|
||||
TaskGraph,
|
||||
TaskService,
|
||||
TaskStatus,
|
||||
StepStatus,
|
||||
task_service
|
||||
)
|
||||
|
||||
|
||||
class TestStep:
|
||||
"""Tests for Step dataclass"""
|
||||
|
||||
def test_step_creation(self):
|
||||
"""Should create step with required fields"""
|
||||
step = Step(id="step_1", name="Test Step")
|
||||
assert step.id == "step_1"
|
||||
assert step.name == "Test Step"
|
||||
assert step.status == StepStatus.PENDING
|
||||
assert step.depends_on == []
|
||||
|
||||
def test_step_with_dependencies(self):
|
||||
"""Should create step with dependencies"""
|
||||
step = Step(
|
||||
id="step_2",
|
||||
name="Dependent Step",
|
||||
depends_on=["step_1"]
|
||||
)
|
||||
assert "step_1" in step.depends_on
|
||||
|
||||
def test_step_to_dict(self):
|
||||
"""Should convert step to dict"""
|
||||
step = Step(id="step_1", name="Test")
|
||||
result = step.to_dict()
|
||||
assert isinstance(result, dict)
|
||||
assert result["id"] == "step_1"
|
||||
assert result["name"] == "Test"
|
||||
assert result["status"] == "pending"
|
||||
|
||||
|
||||
class TestTask:
|
||||
"""Tests for Task dataclass"""
|
||||
|
||||
def test_task_creation(self):
|
||||
"""Should create task with required fields"""
|
||||
task = Task(id="task_1", name="Test Task")
|
||||
assert task.id == "task_1"
|
||||
assert task.name == "Test Task"
|
||||
assert task.status == TaskStatus.PENDING
|
||||
assert task.steps == []
|
||||
|
||||
def test_task_with_steps(self):
|
||||
"""Should create task with steps"""
|
||||
step1 = Step(id="step_1", name="Step 1")
|
||||
step2 = Step(id="step_2", name="Step 2")
|
||||
task = Task(id="task_1", name="Test", steps=[step1, step2])
|
||||
assert len(task.steps) == 2
|
||||
|
||||
def test_task_to_dict(self):
|
||||
"""Should convert task to dict"""
|
||||
task = Task(id="task_1", name="Test", goal="Complete task")
|
||||
result = task.to_dict()
|
||||
assert result["id"] == "task_1"
|
||||
assert result["goal"] == "Complete task"
|
||||
assert result["status"] == "pending"
|
||||
|
||||
|
||||
class TestTaskGraph:
|
||||
"""Tests for TaskGraph class"""
|
||||
|
||||
def test_graph_creation(self):
|
||||
"""Should create graph from task"""
|
||||
task = Task(id="task_1", name="Test")
|
||||
graph = TaskGraph(task)
|
||||
assert graph.task is task
|
||||
|
||||
def test_topological_sort_no_dependencies(self):
|
||||
"""Should sort steps without dependencies"""
|
||||
step1 = Step(id="step_1", name="Step 1")
|
||||
step2 = Step(id="step_2", name="Step 2")
|
||||
task = Task(id="task_1", name="Test", steps=[step1, step2])
|
||||
graph = TaskGraph(task)
|
||||
sorted_steps = graph.topological_sort()
|
||||
assert len(sorted_steps) == 2
|
||||
|
||||
def test_topological_sort_with_dependencies(self):
|
||||
"""Should respect dependencies in sort"""
|
||||
step1 = Step(id="step_1", name="Step 1")
|
||||
step2 = Step(id="step_2", name="Step 2", depends_on=["step_1"])
|
||||
task = Task(id="task_1", name="Test", steps=[step1, step2])
|
||||
graph = TaskGraph(task)
|
||||
sorted_steps = graph.topological_sort()
|
||||
ids = [s.id for s in sorted_steps]
|
||||
assert ids.index("step_1") < ids.index("step_2")
|
||||
|
||||
def test_get_ready_steps(self):
|
||||
"""Should return steps ready to execute"""
|
||||
step1 = Step(id="step_1", name="Step 1")
|
||||
step2 = Step(id="step_2", name="Step 2", depends_on=["step_1"])
|
||||
task = Task(id="task_1", name="Test", steps=[step1, step2])
|
||||
graph = TaskGraph(task)
|
||||
ready = graph.get_ready_steps([])
|
||||
assert len(ready) == 1
|
||||
assert ready[0].id == "step_1"
|
||||
|
||||
def test_get_ready_steps_after_completion(self):
|
||||
"""Should return dependent steps after completion"""
|
||||
step1 = Step(id="step_1", name="Step 1")
|
||||
step2 = Step(id="step_2", name="Step 2", depends_on=["step_1"])
|
||||
task = Task(id="task_1", name="Test", steps=[step1, step2])
|
||||
graph = TaskGraph(task)
|
||||
ready = graph.get_ready_steps(["step_1"])
|
||||
assert len(ready) == 1
|
||||
assert ready[0].id == "step_2"
|
||||
|
||||
def test_detect_cycles_no_cycle(self):
|
||||
"""Should return empty for no cycles"""
|
||||
step1 = Step(id="step_1", name="Step 1")
|
||||
step2 = Step(id="step_2", name="Step 2")
|
||||
task = Task(id="task_1", name="Test", steps=[step1, step2])
|
||||
graph = TaskGraph(task)
|
||||
cycles = graph.detect_cycles()
|
||||
assert cycles == []
|
||||
|
||||
def test_detect_cycles_with_cycle(self):
|
||||
"""Should detect circular dependency"""
|
||||
step1 = Step(id="step_1", name="Step 1", depends_on=["step_2"])
|
||||
step2 = Step(id="step_2", name="Step 2", depends_on=["step_1"])
|
||||
task = Task(id="task_1", name="Test", steps=[step1, step2])
|
||||
graph = TaskGraph(task)
|
||||
cycles = graph.detect_cycles()
|
||||
assert len(cycles) > 0
|
||||
|
||||
def test_validate_valid_graph(self):
|
||||
"""Should validate valid graph"""
|
||||
step1 = Step(id="step_1", name="Step 1")
|
||||
task = Task(id="task_1", name="Test", steps=[step1])
|
||||
graph = TaskGraph(task)
|
||||
is_valid, error = graph.validate()
|
||||
assert is_valid is True
|
||||
assert error is None
|
||||
|
||||
def test_validate_invalid_dependency(self):
|
||||
"""Should fail validation for invalid dependency"""
|
||||
step1 = Step(id="step_1", name="Step 1", depends_on=["nonexistent"])
|
||||
task = Task(id="task_1", name="Test", steps=[step1])
|
||||
graph = TaskGraph(task)
|
||||
is_valid, error = graph.validate()
|
||||
assert is_valid is False
|
||||
assert error is not None
|
||||
|
||||
|
||||
class TestTaskService:
|
||||
"""Tests for TaskService class"""
|
||||
|
||||
def test_create_task(self):
|
||||
"""Should create a new task"""
|
||||
service = TaskService()
|
||||
task = service.create_task(name="Test Task", goal="Complete test")
|
||||
assert task is not None
|
||||
assert task.name == "Test Task"
|
||||
assert task.goal == "Complete test"
|
||||
|
||||
def test_create_task_with_steps(self):
|
||||
"""Should create task with steps"""
|
||||
service = TaskService()
|
||||
steps = [
|
||||
{"name": "Step 1", "description": "First step"},
|
||||
{"name": "Step 2", "description": "Second step"}
|
||||
]
|
||||
task = service.create_task(name="Test", goal="Goal", steps=steps)
|
||||
assert len(task.steps) == 2
|
||||
|
||||
def test_get_task(self):
|
||||
"""Should retrieve task by id"""
|
||||
service = TaskService()
|
||||
created = service.create_task(name="Test", goal="Goal")
|
||||
retrieved = service.get_task(created.id)
|
||||
assert retrieved is not None
|
||||
assert retrieved.id == created.id
|
||||
|
||||
def test_get_nonexistent_task(self):
|
||||
"""Should return None for nonexistent task"""
|
||||
service = TaskService()
|
||||
result = service.get_task("nonexistent_id")
|
||||
assert result is None
|
||||
|
||||
def test_update_task_status(self):
|
||||
"""Should update task status"""
|
||||
service = TaskService()
|
||||
task = service.create_task(name="Test", goal="Goal")
|
||||
updated = service.update_task_status(task.id, TaskStatus.RUNNING)
|
||||
assert updated is not None
|
||||
assert updated.status == TaskStatus.RUNNING
|
||||
|
||||
def test_add_steps(self):
|
||||
"""Should add steps to existing task"""
|
||||
service = TaskService()
|
||||
task = service.create_task(name="Test", goal="Goal")
|
||||
steps = [{"name": "New Step"}]
|
||||
added = service.add_steps(task.id, steps)
|
||||
assert added is not None
|
||||
assert len(added) == 1
|
||||
assert len(task.steps) == 1
|
||||
|
||||
def test_delete_task(self):
|
||||
"""Should delete task"""
|
||||
service = TaskService()
|
||||
task = service.create_task(name="Test", goal="Goal")
|
||||
result = service.delete_task(task.id)
|
||||
assert result is True
|
||||
assert service.get_task(task.id) is None
|
||||
|
||||
def test_build_graph(self):
|
||||
"""Should build graph for task"""
|
||||
service = TaskService()
|
||||
task = service.create_task(name="Test", goal="Goal")
|
||||
graph = service.build_graph(task.id)
|
||||
assert graph is not None
|
||||
assert isinstance(graph, TaskGraph)
|
||||
|
|
@ -0,0 +1,313 @@
|
|||
"""Tests for tools module"""
|
||||
import pytest
|
||||
from luxx.tools.core import (
|
||||
ToolContext,
|
||||
ToolDefinition,
|
||||
ToolResult,
|
||||
ToolRegistry,
|
||||
CommandPermission
|
||||
)
|
||||
|
||||
|
||||
class TestToolContext:
|
||||
"""Tests for ToolContext dataclass"""
|
||||
|
||||
def test_tool_context_creation(self):
|
||||
"""Should create context with default values"""
|
||||
ctx = ToolContext()
|
||||
assert ctx.workspace is None
|
||||
assert ctx.user_id is None
|
||||
assert ctx.username is None
|
||||
assert ctx.extra == {}
|
||||
|
||||
def test_tool_context_with_values(self):
|
||||
"""Should create context with provided values"""
|
||||
ctx = ToolContext(
|
||||
workspace="/workspace/test",
|
||||
user_id=1,
|
||||
username="testuser",
|
||||
extra={"key": "value"}
|
||||
)
|
||||
assert ctx.workspace == "/workspace/test"
|
||||
assert ctx.user_id == 1
|
||||
assert ctx.username == "testuser"
|
||||
assert ctx.extra["key"] == "value"
|
||||
|
||||
|
||||
class TestToolDefinition:
|
||||
"""Tests for ToolDefinition dataclass"""
|
||||
|
||||
def test_tool_definition_creation(self):
|
||||
"""Should create tool definition"""
|
||||
def handler(args):
|
||||
return {"result": "ok"}
|
||||
|
||||
tool = ToolDefinition(
|
||||
name="test_tool",
|
||||
description="A test tool",
|
||||
parameters={"type": "object"},
|
||||
handler=handler
|
||||
)
|
||||
assert tool.name == "test_tool"
|
||||
assert tool.description == "A test tool"
|
||||
assert tool.category == "general"
|
||||
assert tool.required_permission == CommandPermission.READ_ONLY
|
||||
|
||||
def test_tool_definition_to_openai_format(self):
|
||||
"""Should convert to OpenAI format"""
|
||||
def handler(args):
|
||||
return {"result": "ok"}
|
||||
|
||||
tool = ToolDefinition(
|
||||
name="test_tool",
|
||||
description="A test tool",
|
||||
parameters={"type": "object", "properties": {}},
|
||||
handler=handler
|
||||
)
|
||||
result = tool.to_openai_format()
|
||||
assert result["type"] == "function"
|
||||
assert result["function"]["name"] == "test_tool"
|
||||
|
||||
|
||||
class TestToolResult:
|
||||
"""Tests for ToolResult dataclass"""
|
||||
|
||||
def test_tool_result_success(self):
|
||||
"""Should create success result"""
|
||||
result = ToolResult(success=True, data={"key": "value"})
|
||||
assert result.success is True
|
||||
assert result.data["key"] == "value"
|
||||
assert result.error is None
|
||||
|
||||
def test_tool_result_failure(self):
|
||||
"""Should create failure result"""
|
||||
result = ToolResult(success=False, error="Something went wrong")
|
||||
assert result.success is False
|
||||
assert result.error == "Something went wrong"
|
||||
|
||||
def test_tool_result_to_dict(self):
|
||||
"""Should convert to dictionary"""
|
||||
result = ToolResult(success=True, data={"key": "value"})
|
||||
d = result.to_dict()
|
||||
assert isinstance(d, dict)
|
||||
assert d["success"] is True
|
||||
assert d["data"]["key"] == "value"
|
||||
|
||||
def test_tool_result_ok_factory(self):
|
||||
"""Should use ok() factory method"""
|
||||
result = ToolResult.ok({"result": "success"})
|
||||
assert result.success is True
|
||||
assert result.data == {"result": "success"}
|
||||
|
||||
def test_tool_result_fail_factory(self):
|
||||
"""Should use fail() factory method"""
|
||||
result = ToolResult.fail("Error occurred")
|
||||
assert result.success is False
|
||||
assert result.error == "Error occurred"
|
||||
|
||||
|
||||
class TestToolRegistry:
|
||||
"""Tests for ToolRegistry class"""
|
||||
|
||||
def test_registry_singleton(self):
|
||||
"""Should return same instance"""
|
||||
reg1 = ToolRegistry()
|
||||
reg2 = ToolRegistry()
|
||||
assert reg1 is reg2
|
||||
|
||||
def test_register_tool(self):
|
||||
"""Should register a tool"""
|
||||
registry = ToolRegistry()
|
||||
registry.clear() # Start fresh
|
||||
|
||||
def handler(args):
|
||||
return {"result": "ok"}
|
||||
|
||||
tool = ToolDefinition(
|
||||
name="my_tool",
|
||||
description="My test tool",
|
||||
parameters={},
|
||||
handler=handler
|
||||
)
|
||||
registry.register(tool)
|
||||
assert registry.get("my_tool") is not None
|
||||
assert registry.tool_count() == 1
|
||||
|
||||
def test_get_nonexistent_tool(self):
|
||||
"""Should return None for nonexistent tool"""
|
||||
registry = ToolRegistry()
|
||||
registry.clear()
|
||||
assert registry.get("nonexistent") is None
|
||||
|
||||
def test_list_all_tools(self):
|
||||
"""Should list all registered tools"""
|
||||
registry = ToolRegistry()
|
||||
registry.clear()
|
||||
|
||||
def handler(args):
|
||||
return {}
|
||||
|
||||
tool1 = ToolDefinition(name="tool1", description="Tool 1", parameters={}, handler=handler)
|
||||
tool2 = ToolDefinition(name="tool2", description="Tool 2", parameters={}, handler=handler)
|
||||
registry.register(tool1)
|
||||
registry.register(tool2)
|
||||
|
||||
tools = registry.list_all()
|
||||
assert len(tools) == 2
|
||||
|
||||
def test_list_by_category(self):
|
||||
"""Should filter tools by category"""
|
||||
registry = ToolRegistry()
|
||||
registry.clear()
|
||||
|
||||
def handler(args):
|
||||
return {}
|
||||
|
||||
tool1 = ToolDefinition(
|
||||
name="tool1", description="Tool 1", parameters={},
|
||||
handler=handler, category="code"
|
||||
)
|
||||
tool2 = ToolDefinition(
|
||||
name="tool2", description="Tool 2", parameters={},
|
||||
handler=handler, category="file"
|
||||
)
|
||||
registry.register(tool1)
|
||||
registry.register(tool2)
|
||||
|
||||
code_tools = registry.list_by_category("code")
|
||||
assert len(code_tools) == 1
|
||||
|
||||
def test_execute_tool(self):
|
||||
"""Should execute a tool"""
|
||||
registry = ToolRegistry()
|
||||
registry.clear()
|
||||
|
||||
def handler(args):
|
||||
return {"executed": True, "args": args}
|
||||
|
||||
tool = ToolDefinition(
|
||||
name="test_tool",
|
||||
description="Test tool",
|
||||
parameters={},
|
||||
handler=handler
|
||||
)
|
||||
registry.register(tool)
|
||||
|
||||
result = registry.execute("test_tool", {"input": "value"})
|
||||
# Direct handler returns are passed through as-is
|
||||
assert result["executed"] is True
|
||||
assert result["args"]["input"] == "value"
|
||||
|
||||
def test_execute_tool_with_tool_result(self):
|
||||
"""Should return ToolResult when handler returns ToolResult"""
|
||||
registry = ToolRegistry()
|
||||
registry.clear()
|
||||
|
||||
def handler(args):
|
||||
return ToolResult.ok({"executed": True})
|
||||
|
||||
tool = ToolDefinition(
|
||||
name="test_tool",
|
||||
description="Test tool",
|
||||
parameters={},
|
||||
handler=handler
|
||||
)
|
||||
registry.register(tool)
|
||||
|
||||
result = registry.execute("test_tool", {})
|
||||
assert result["success"] is True
|
||||
assert result["data"]["executed"] is True
|
||||
|
||||
def test_execute_nonexistent_tool(self):
|
||||
"""Should return error for nonexistent tool"""
|
||||
registry = ToolRegistry()
|
||||
registry.clear()
|
||||
|
||||
result = registry.execute("nonexistent", {})
|
||||
assert result["success"] is False
|
||||
assert "not found" in result["error"]
|
||||
|
||||
def test_execute_with_context(self):
|
||||
"""Should pass context to handler"""
|
||||
registry = ToolRegistry()
|
||||
registry.clear()
|
||||
received_context = None
|
||||
|
||||
def handler(args, context=None):
|
||||
nonlocal received_context
|
||||
received_context = context
|
||||
return ToolResult.ok({"received": True})
|
||||
|
||||
tool = ToolDefinition(
|
||||
name="test_tool",
|
||||
description="Test tool",
|
||||
parameters={},
|
||||
handler=handler
|
||||
)
|
||||
registry.register(tool)
|
||||
|
||||
ctx = ToolContext(user_id=1, username="test")
|
||||
registry.execute("test_tool", {}, context=ctx)
|
||||
assert received_context is not None
|
||||
assert received_context.user_id == 1
|
||||
|
||||
def test_permission_check(self):
|
||||
"""Should check user permission"""
|
||||
registry = ToolRegistry()
|
||||
registry.clear()
|
||||
|
||||
def handler(args):
|
||||
return ToolResult.ok({"ok": True})
|
||||
|
||||
tool = ToolDefinition(
|
||||
name="admin_tool",
|
||||
description="Admin tool",
|
||||
parameters={},
|
||||
handler=handler,
|
||||
required_permission=CommandPermission.ADMIN
|
||||
)
|
||||
registry.register(tool)
|
||||
|
||||
# User with low permission
|
||||
ctx = ToolContext(
|
||||
user_id=1,
|
||||
extra={"user_permission_level": CommandPermission.READ_ONLY}
|
||||
)
|
||||
result = registry.execute("admin_tool", {}, context=ctx)
|
||||
assert result["success"] is False
|
||||
assert "Permission denied" in result["error"]
|
||||
|
||||
def test_remove_tool(self):
|
||||
"""Should remove a tool"""
|
||||
registry = ToolRegistry()
|
||||
registry.clear()
|
||||
|
||||
def handler(args):
|
||||
return {}
|
||||
|
||||
tool = ToolDefinition(
|
||||
name="removable",
|
||||
description="To be removed",
|
||||
parameters={},
|
||||
handler=handler
|
||||
)
|
||||
registry.register(tool)
|
||||
assert registry.get("removable") is not None
|
||||
|
||||
registry.remove("removable")
|
||||
assert registry.get("removable") is None
|
||||
|
||||
def test_clear_tools(self):
|
||||
"""Should clear all tools"""
|
||||
registry = ToolRegistry()
|
||||
registry.clear()
|
||||
|
||||
def handler(args):
|
||||
return {}
|
||||
|
||||
tool = ToolDefinition(name="tool1", description="", parameters={}, handler=handler)
|
||||
registry.register(tool)
|
||||
assert registry.tool_count() > 0
|
||||
|
||||
registry.clear()
|
||||
assert registry.tool_count() == 0
|
||||
Loading…
Reference in New Issue