Compare commits

...

2 Commits

Author SHA1 Message Date
ViperEkura dc08267c15 feat: 增加task实现 2026-04-17 23:05:54 +08:00
ViperEkura f10909bec3 chore: 优化chat部分结构 2026-04-17 23:01:48 +08:00
2 changed files with 489 additions and 145 deletions

View File

@ -45,6 +45,157 @@ def get_llm_client(conversation: Conversation = None):
return client, max_tokens return client, max_tokens
class StreamContext:
"""Context for streaming response state management."""
def __init__(
self,
step_index: int = 0,
current_step_id: str = None,
current_step_idx: int = None,
current_stream_type: str = None,
full_content: str = "",
full_thinking: str = ""
):
self.step_index = step_index
self.current_step_id = current_step_id
self.current_step_idx = current_step_idx
self.current_stream_type = current_stream_type
self.full_content = full_content
self.full_thinking = full_thinking
self.all_steps = []
self.all_tool_calls = []
self.all_tool_results = []
self.tool_calls_list = []
def reset_iteration(self):
"""Reset streaming step tracker for new iteration."""
self.current_step_id = None
self.current_step_idx = None
self.current_stream_type = None
self.full_content = ""
self.full_thinking = ""
self.tool_calls_list = []
def start_stream_step(self, step_type: str) -> str:
"""Start a new streaming step. Returns the step_id."""
self.current_step_idx = self.step_index
self.current_step_id = f"step-{self.step_index}"
self.current_stream_type = step_type
self.step_index += 1
return self.current_step_id
def yield_stream_step(self, step_type: str, content: str) -> Dict[str, Any]:
"""Yield a streaming step event."""
return _sse_event("process_step", {
"step": {
"id": self.current_step_id,
"index": self.current_step_idx,
"type": step_type,
"content": content
}
})
def save_streaming_step(self):
"""Save the current streaming step to all_steps."""
if self.current_step_id is None:
return
if self.current_stream_type == "thinking":
self.all_steps.append({
"id": self.current_step_id,
"index": self.current_step_idx,
"type": "thinking",
"content": self.full_thinking
})
elif self.current_stream_type == "text":
self.all_steps.append({
"id": self.current_step_id,
"index": self.current_step_idx,
"type": "text",
"content": self.full_content
})
def handle_thinking_stream(self, delta: Dict) -> Optional[Dict]:
"""Handle reasoning/thinking delta. Returns yield_obj if step was yielded."""
reasoning = delta.get("reasoning_content", "")
if not reasoning:
return None
prev_len = len(self.full_thinking)
self.full_thinking += reasoning
if prev_len == 0: # New thinking stream started
self.start_stream_step("thinking")
return self.yield_stream_step("thinking", self.full_thinking)
def handle_text_stream(self, delta: Dict) -> Optional[Dict]:
"""Handle content delta. Returns yield_obj if step was yielded."""
content = delta.get("content", "")
if not content:
return None
prev_len = len(self.full_content)
self.full_content += content
if prev_len == 0: # New text stream started
self.start_stream_step("text")
return self.yield_stream_step("text", self.full_content)
def handle_tool_call(self) -> tuple:
"""Handle tool calls. Returns (tool_call_step_ids, tool_call_steps, yield_objs)."""
tool_call_step_ids = []
tool_call_steps = []
yield_objs = []
for tc in self.tool_calls_list:
call_step_idx = self.step_index
call_step_id = f"step-{self.step_index}"
tool_call_step_ids.append(call_step_id)
self.step_index += 1
call_step = {
"id": call_step_id,
"index": call_step_idx,
"type": "tool_call",
"id_ref": tc.get("id", ""),
"name": tc["function"]["name"],
"arguments": tc["function"]["arguments"]
}
tool_call_steps.append(call_step)
yield_objs.append(_sse_event("process_step", {"step": call_step}))
return tool_call_step_ids, tool_call_steps, yield_objs
def handle_tool_result(self, tool_result: Dict, tool_call_step_id: str) -> tuple:
"""Handle single tool result. Returns (result_step, yield_obj)."""
result_step_idx = self.step_index
result_step_id = f"step-{self.step_index}"
self.step_index += 1
content = tool_result.get("content", "")
success = True
try:
content_obj = json.loads(content)
if isinstance(content_obj, dict):
success = content_obj.get("success", True)
except:
pass
result_step = {
"id": result_step_id,
"index": result_step_idx,
"type": "tool_result",
"id_ref": tool_call_step_id,
"name": tool_result.get("name", ""),
"content": content,
"success": success
}
return result_step, _sse_event("process_step", {"step": result_step})
class ChatService: class ChatService:
"""Chat service with tool support""" """Chat service with tool support"""
@ -129,12 +280,6 @@ class ChatService:
# 直接使用 provider 的 max_tokens # 直接使用 provider 的 max_tokens
max_tokens = provider_max_tokens max_tokens = provider_max_tokens
# State tracking
all_steps = []
all_tool_calls = []
all_tool_results = []
step_index = 0
# Token usage tracking # Token usage tracking
total_usage = { total_usage = {
"prompt_tokens": 0, "prompt_tokens": 0,
@ -142,23 +287,12 @@ class ChatService:
"total_tokens": 0 "total_tokens": 0
} }
# Global step IDs for thinking and text (persist across iterations) # Streaming context for state management
thinking_step_id = None ctx = StreamContext()
thinking_step_idx = None
text_step_id = None
text_step_idx = None
for iteration in range(MAX_ITERATIONS): for iteration in range(MAX_ITERATIONS):
# Stream from LLM # Reset streaming context for this iteration
full_content = "" ctx.reset_iteration()
full_thinking = ""
tool_calls_list = []
# Step tracking - use unified step-{index} format
thinking_step_id = None
thinking_step_idx = None
text_step_id = None
text_step_idx = None
async for sse_line in llm.stream_call( async for sse_line in llm.stream_call(
model=model, model=model,
@ -218,19 +352,16 @@ class ChatService:
if chunk.get("content") or chunk.get("message"): if chunk.get("content") or chunk.get("message"):
content = chunk.get("content") or chunk.get("message", {}).get("content", "") content = chunk.get("content") or chunk.get("message", {}).get("content", "")
if content: if content:
# BUG FIX: Update full_content so it gets saved to database prev_len = len(ctx.full_content)
prev_content_len = len(full_content) ctx.full_content += content
full_content += content if prev_len == 0: # New text stream started
if prev_content_len == 0: # New text stream started ctx.start_stream_step("text")
text_step_idx = step_index
text_step_id = f"step-{step_index}"
step_index += 1
yield _sse_event("process_step", { yield _sse_event("process_step", {
"step": { "step": {
"id": text_step_id if prev_content_len == 0 else f"step-{step_index - 1}", "id": ctx.current_step_id if prev_len == 0 else f"step-{ctx.step_index - 1}",
"index": text_step_idx if prev_content_len == 0 else step_index - 1, "index": ctx.current_step_idx if prev_len == 0 else ctx.step_index - 1,
"type": "text", "type": "text",
"content": full_content # Always send accumulated content "content": ctx.full_content
} }
}) })
continue continue
@ -238,96 +369,43 @@ class ChatService:
delta = choices[0].get("delta", {}) delta = choices[0].get("delta", {})
# Handle reasoning (thinking) # Handle reasoning (thinking)
reasoning = delta.get("reasoning_content", "") yield_obj = ctx.handle_thinking_stream(delta)
if reasoning: if yield_obj:
prev_thinking_len = len(full_thinking) yield yield_obj
full_thinking += reasoning
if prev_thinking_len == 0: # New thinking stream started
thinking_step_idx = step_index
thinking_step_id = f"step-{step_index}"
step_index += 1
yield _sse_event("process_step", {
"step": {
"id": thinking_step_id,
"index": thinking_step_idx,
"type": "thinking",
"content": full_thinking
}
})
# Handle content # Handle content
content = delta.get("content", "") yield_obj = ctx.handle_text_stream(delta)
if content: if yield_obj:
prev_content_len = len(full_content) yield yield_obj
full_content += content
if prev_content_len == 0: # New text stream started
text_step_idx = step_index
text_step_id = f"step-{step_index}"
step_index += 1
yield _sse_event("process_step", {
"step": {
"id": text_step_id,
"index": text_step_idx,
"type": "text",
"content": full_content
}
})
# Accumulate tool calls # Accumulate tool calls
tool_calls_delta = delta.get("tool_calls", []) tool_calls_delta = delta.get("tool_calls", [])
for tc in tool_calls_delta: for tc in tool_calls_delta:
idx = tc.get("index", 0) idx = tc.get("index", 0)
if idx >= len(tool_calls_list): if idx >= len(ctx.tool_calls_list):
tool_calls_list.append({ ctx.tool_calls_list.append({
"id": tc.get("id", ""), "id": tc.get("id", ""),
"type": "function", "type": "function",
"function": {"name": "", "arguments": ""} "function": {"name": "", "arguments": ""}
}) })
func = tc.get("function", {}) func = tc.get("function", {})
if func.get("name"): if func.get("name"):
tool_calls_list[idx]["function"]["name"] += func["name"] ctx.tool_calls_list[idx]["function"]["name"] += func["name"]
if func.get("arguments"): if func.get("arguments"):
tool_calls_list[idx]["function"]["arguments"] += func["arguments"] ctx.tool_calls_list[idx]["function"]["arguments"] += func["arguments"]
# Save thinking step # Save streaming step (thinking or text)
if thinking_step_id is not None: ctx.save_streaming_step()
all_steps.append({
"id": thinking_step_id,
"index": thinking_step_idx,
"type": "thinking",
"content": full_thinking
})
# Save text step
if text_step_id is not None:
all_steps.append({
"id": text_step_id,
"index": text_step_idx,
"type": "text",
"content": full_content
})
# Handle tool calls # Handle tool calls
if tool_calls_list: if ctx.tool_calls_list:
all_tool_calls.extend(tool_calls_list) ctx.all_tool_calls.extend(ctx.tool_calls_list)
# Yield tool_call steps - use unified step-{index} format # Handle tool_call steps
tool_call_step_ids = [] # Track step IDs for tool calls tool_call_step_ids, tool_call_steps, yield_objs = ctx.handle_tool_call()
for tc in tool_calls_list: ctx.all_steps.extend(tool_call_steps)
call_step_idx = step_index for yield_obj in yield_objs:
call_step_id = f"step-{step_index}" yield yield_obj
tool_call_step_ids.append(call_step_id)
step_index += 1
call_step = {
"id": call_step_id,
"index": call_step_idx,
"type": "tool_call",
"id_ref": tc.get("id", ""),
"name": tc["function"]["name"],
"arguments": tc["function"]["arguments"]
}
all_steps.append(call_step)
yield _sse_event("process_step", {"step": call_step})
# Execute tools # Execute tools
tool_context = { tool_context = {
@ -337,39 +415,17 @@ class ChatService:
"user_permission_level": user_permission_level "user_permission_level": user_permission_level
} }
tool_results = self.tool_executor.process_tool_calls_parallel( tool_results = self.tool_executor.process_tool_calls_parallel(
tool_calls_list, tool_context ctx.tool_calls_list, tool_context
) )
# Yield tool_result steps - use unified step-{index} format # Handle tool_result steps
for i, tr in enumerate(tool_results): for i, tr in enumerate(tool_results):
tool_call_step_id = tool_call_step_ids[i] if i < len(tool_call_step_ids) else f"step-{i}" tool_call_step_id = tool_call_step_ids[i] if i < len(tool_call_step_ids) else f"step-{i}"
result_step_idx = step_index result_step, yield_obj = ctx.handle_tool_result(tr, tool_call_step_id)
result_step_id = f"step-{step_index}" ctx.all_steps.append(result_step)
step_index += 1 yield yield_obj
# 解析 content 中的 success 状态 ctx.all_tool_results.append({
content = tr.get("content", "")
success = True
try:
content_obj = json.loads(content)
if isinstance(content_obj, dict):
success = content_obj.get("success", True)
except:
pass
result_step = {
"id": result_step_id,
"index": result_step_idx,
"type": "tool_result",
"id_ref": tool_call_step_id, # Reference to the tool_call step
"name": tr.get("name", ""),
"content": content,
"success": success
}
all_steps.append(result_step)
yield _sse_event("process_step", {"step": result_step})
all_tool_results.append({
"role": "tool", "role": "tool",
"tool_call_id": tr.get("tool_call_id", ""), "tool_call_id": tr.get("tool_call_id", ""),
"content": tr.get("content", "") "content": tr.get("content", "")
@ -378,27 +434,27 @@ class ChatService:
# Add assistant message with tool calls for next iteration # Add assistant message with tool calls for next iteration
messages.append({ messages.append({
"role": "assistant", "role": "assistant",
"content": full_content or "", "content": ctx.full_content or "",
"tool_calls": tool_calls_list "tool_calls": ctx.tool_calls_list
}) })
messages.extend(all_tool_results[-len(tool_results):]) messages.extend(ctx.all_tool_results[-len(tool_results):])
all_tool_results = [] ctx.all_tool_results = []
continue continue
# No tool calls - final iteration, save message # No tool calls - final iteration, save message
msg_id = str(uuid.uuid4()) msg_id = str(uuid.uuid4())
# 使用 API 返回的真实 completion_tokens如果 API 没返回则降级使用估算值 # 使用 API 返回的真实 completion_tokens如果 API 没返回则降级使用估算值
actual_token_count = total_usage.get("completion_tokens", 0) or len(full_content) // 4 actual_token_count = total_usage.get("completion_tokens", 0) or len(ctx.full_content) // 4
logger.info(f"[TOKEN] total_usage: {total_usage}, actual_token_count: {actual_token_count}") logger.info(f"[TOKEN] total_usage: {total_usage}, actual_token_count: {actual_token_count}")
self._save_message( self._save_message(
conversation.id, conversation.id,
msg_id, msg_id,
full_content, ctx.full_content,
all_tool_calls, ctx.all_tool_calls,
all_tool_results, ctx.all_tool_results,
all_steps, ctx.all_steps,
actual_token_count, actual_token_count,
total_usage total_usage
) )
@ -411,15 +467,15 @@ class ChatService:
return return
# Max iterations exceeded - save message before error # Max iterations exceeded - save message before error
if full_content or all_tool_calls: if ctx.full_content or ctx.all_tool_calls:
msg_id = str(uuid.uuid4()) msg_id = str(uuid.uuid4())
self._save_message( self._save_message(
conversation.id, conversation.id,
msg_id, msg_id,
full_content, ctx.full_content,
all_tool_calls, ctx.all_tool_calls,
all_tool_results, ctx.all_tool_results,
all_steps, ctx.all_steps,
actual_token_count, actual_token_count,
total_usage total_usage
) )

288
luxx/services/task.py Normal file
View File

@ -0,0 +1,288 @@
"""Task module for autonomous task execution"""
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
import logging
from typing import List, Optional, Dict, Any
from luxx.utils.helpers import generate_id
logger = logging.getLogger(__name__)
class TaskStatus(Enum):
"""Task status enum"""
PENDING = "pending"
READY = "ready"
RUNNING = "running"
BLOCK = "block"
TERMINATED = "terminated"
class StepStatus(Enum):
"""Step status enum"""
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
SKIPPED = "skipped"
@dataclass
class Step:
"""Task step"""
id: str
name: str
description: str = ""
depends_on: List[str] = field(default_factory=list)
status: StepStatus = StepStatus.PENDING
result: Optional[Dict[str, Any]] = None
created_at: datetime = field(default_factory=datetime.now)
updated_at: datetime = field(default_factory=datetime.now)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary"""
return {
"id": self.id,
"name": self.name,
"description": self.description,
"depends_on": self.depends_on,
"status": self.status.value,
"result": self.result,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None
}
@dataclass
class Task:
"""Task entity"""
id: str
name: str
description: str = ""
goal: str = ""
status: TaskStatus = TaskStatus.PENDING
steps: List[Step] = field(default_factory=list)
subtasks: List["Task"] = field(default_factory=list)
result: Optional[Dict[str, Any]] = None
created_at: datetime = field(default_factory=datetime.now)
updated_at: datetime = field(default_factory=datetime.now)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary"""
return {
"id": self.id,
"name": self.name,
"description": self.description,
"goal": self.goal,
"status": self.status.value,
"steps": [s.to_dict() for s in self.steps],
"subtasks": [t.to_dict() for t in self.subtasks],
"result": self.result,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None
}
class TaskGraph:
"""Task graph for managing step dependencies"""
def __init__(self, task: Task):
self.task = task
self._adjacency: Dict[str, List[str]] = {}
self._reverse_adjacency: Dict[str, List[str]] = {}
self._in_degree: Dict[str, int] = {}
self._build_graph()
def _build_graph(self) -> None:
"""Build graph from task steps"""
for step in self.task.steps:
self._adjacency[step.id] = []
self._reverse_adjacency[step.id] = []
self._in_degree[step.id] = 0
for step in self.task.steps:
for dep_id in step.depends_on:
if dep_id in self._adjacency:
self._adjacency[dep_id].append(step.id)
self._reverse_adjacency[step.id].append(dep_id)
self._in_degree[step.id] += 1
def topological_sort(self) -> List[Step]:
"""Get steps in topological order"""
in_degree = self._in_degree.copy()
queue = [step_id for step_id, degree in in_degree.items() if degree == 0]
result = []
step_map = {step.id: step for step in self.task.steps}
while queue:
queue.sort()
current = queue.pop(0)
result.append(step_map[current])
for dependent_id in self._adjacency[current]:
in_degree[dependent_id] -= 1
if in_degree[dependent_id] == 0:
queue.append(dependent_id)
return result
def get_ready_steps(self, completed_step_ids: List[str]) -> List[Step]:
"""Get steps that are ready to execute"""
step_map = {step.id: step for step in self.task.steps}
ready = []
for step in self.task.steps:
if step.id in completed_step_ids:
continue
if step.status != StepStatus.PENDING:
continue
deps_completed = all(dep_id in completed_step_ids for dep_id in step.depends_on)
if deps_completed:
ready.append(step)
return ready
def detect_cycles(self) -> List[List[str]]:
"""Detect cycles in the graph"""
WHITE, GRAY, BLACK = 0, 1, 2
color = {step.id: WHITE for step in self.task.steps}
cycles = []
def dfs(node: str, path: List[str]) -> bool:
color[node] = GRAY
path.append(node)
for neighbor in self._adjacency.get(node, []):
if color[neighbor] == GRAY:
cycle_start = path.index(neighbor)
cycles.append(path[cycle_start:] + [neighbor])
return True
elif color[neighbor] == WHITE:
if dfs(neighbor, path):
return True
path.pop()
color[node] = BLACK
return False
for step in self.task.steps:
if color[step.id] == WHITE:
dfs(step.id, [])
return cycles
def validate(self) -> tuple[bool, Optional[str]]:
"""Validate the graph structure"""
cycles = self.detect_cycles()
if cycles:
return False, f"Circular dependency detected: {cycles[0]}"
step_ids = {step.id for step in self.task.steps}
for step in self.task.steps:
for dep_id in step.depends_on:
if dep_id not in step_ids:
return False, f"Step '{step.name}' depends on non-existent step '{dep_id}'"
return True, None
class TaskService:
"""Task service for managing tasks"""
def __init__(self):
self._tasks: Dict[str, Task] = {}
def create_task(
self,
name: str,
goal: str,
description: str = "",
steps: List[Dict[str, Any]] = None
) -> Task:
"""Create a new task"""
task_id = generate_id("task")
task = Task(
id=task_id,
name=name,
description=description,
goal=goal
)
if steps:
for step_data in steps:
step = Step(
id=generate_id("step"),
name=step_data.get("name", ""),
description=step_data.get("description", "")
)
task.steps.append(step)
self._tasks[task_id] = task
logger.info(f"Created task: {task_id}")
return task
def get_task(self, task_id: str) -> Optional[Task]:
"""Get task by ID"""
return self._tasks.get(task_id)
def list_tasks(self) -> List[Task]:
"""List all tasks"""
return list(self._tasks.values())
def update_task_status(
self,
task_id: str,
status: TaskStatus,
result: Any = None
) -> Optional[Task]:
"""Update task status"""
task = self._tasks.get(task_id)
if not task:
return None
task.status = status
task.result = result
task.updated_at = datetime.now()
return task
def add_steps(
self,
task_id: str,
steps: List[Dict[str, Any]]
) -> Optional[List[Step]]:
"""Add steps to task"""
task = self._tasks.get(task_id)
if not task:
return None
result = []
for step_data in steps:
step = Step(
id=generate_id("step"),
name=step_data.get("name", ""),
description=step_data.get("description", ""),
depends_on=step_data.get("depends_on", [])
)
task.steps.append(step)
result.append(step)
task.updated_at = datetime.now()
return result
def delete_task(self, task_id: str) -> bool:
"""Delete task"""
if task_id not in self._tasks:
return False
del self._tasks[task_id]
return True
def build_graph(self, task_id: str) -> Optional[TaskGraph]:
"""Build task graph for a task"""
task = self._tasks.get(task_id)
if not task:
return None
return TaskGraph(task)
task_service = TaskService()