Compare commits

...

2 Commits

Author SHA1 Message Date
ViperEkura dc08267c15 feat: 增加task实现 2026-04-17 23:05:54 +08:00
ViperEkura f10909bec3 chore: 优化chat部分结构 2026-04-17 23:01:48 +08:00
2 changed files with 489 additions and 145 deletions

View File

@ -45,6 +45,157 @@ def get_llm_client(conversation: Conversation = None):
return client, max_tokens
class StreamContext:
"""Context for streaming response state management."""
def __init__(
self,
step_index: int = 0,
current_step_id: str = None,
current_step_idx: int = None,
current_stream_type: str = None,
full_content: str = "",
full_thinking: str = ""
):
self.step_index = step_index
self.current_step_id = current_step_id
self.current_step_idx = current_step_idx
self.current_stream_type = current_stream_type
self.full_content = full_content
self.full_thinking = full_thinking
self.all_steps = []
self.all_tool_calls = []
self.all_tool_results = []
self.tool_calls_list = []
def reset_iteration(self):
"""Reset streaming step tracker for new iteration."""
self.current_step_id = None
self.current_step_idx = None
self.current_stream_type = None
self.full_content = ""
self.full_thinking = ""
self.tool_calls_list = []
def start_stream_step(self, step_type: str) -> str:
"""Start a new streaming step. Returns the step_id."""
self.current_step_idx = self.step_index
self.current_step_id = f"step-{self.step_index}"
self.current_stream_type = step_type
self.step_index += 1
return self.current_step_id
def yield_stream_step(self, step_type: str, content: str) -> Dict[str, Any]:
"""Yield a streaming step event."""
return _sse_event("process_step", {
"step": {
"id": self.current_step_id,
"index": self.current_step_idx,
"type": step_type,
"content": content
}
})
def save_streaming_step(self):
"""Save the current streaming step to all_steps."""
if self.current_step_id is None:
return
if self.current_stream_type == "thinking":
self.all_steps.append({
"id": self.current_step_id,
"index": self.current_step_idx,
"type": "thinking",
"content": self.full_thinking
})
elif self.current_stream_type == "text":
self.all_steps.append({
"id": self.current_step_id,
"index": self.current_step_idx,
"type": "text",
"content": self.full_content
})
def handle_thinking_stream(self, delta: Dict) -> Optional[Dict]:
"""Handle reasoning/thinking delta. Returns yield_obj if step was yielded."""
reasoning = delta.get("reasoning_content", "")
if not reasoning:
return None
prev_len = len(self.full_thinking)
self.full_thinking += reasoning
if prev_len == 0: # New thinking stream started
self.start_stream_step("thinking")
return self.yield_stream_step("thinking", self.full_thinking)
def handle_text_stream(self, delta: Dict) -> Optional[Dict]:
"""Handle content delta. Returns yield_obj if step was yielded."""
content = delta.get("content", "")
if not content:
return None
prev_len = len(self.full_content)
self.full_content += content
if prev_len == 0: # New text stream started
self.start_stream_step("text")
return self.yield_stream_step("text", self.full_content)
def handle_tool_call(self) -> tuple:
"""Handle tool calls. Returns (tool_call_step_ids, tool_call_steps, yield_objs)."""
tool_call_step_ids = []
tool_call_steps = []
yield_objs = []
for tc in self.tool_calls_list:
call_step_idx = self.step_index
call_step_id = f"step-{self.step_index}"
tool_call_step_ids.append(call_step_id)
self.step_index += 1
call_step = {
"id": call_step_id,
"index": call_step_idx,
"type": "tool_call",
"id_ref": tc.get("id", ""),
"name": tc["function"]["name"],
"arguments": tc["function"]["arguments"]
}
tool_call_steps.append(call_step)
yield_objs.append(_sse_event("process_step", {"step": call_step}))
return tool_call_step_ids, tool_call_steps, yield_objs
def handle_tool_result(self, tool_result: Dict, tool_call_step_id: str) -> tuple:
"""Handle single tool result. Returns (result_step, yield_obj)."""
result_step_idx = self.step_index
result_step_id = f"step-{self.step_index}"
self.step_index += 1
content = tool_result.get("content", "")
success = True
try:
content_obj = json.loads(content)
if isinstance(content_obj, dict):
success = content_obj.get("success", True)
except:
pass
result_step = {
"id": result_step_id,
"index": result_step_idx,
"type": "tool_result",
"id_ref": tool_call_step_id,
"name": tool_result.get("name", ""),
"content": content,
"success": success
}
return result_step, _sse_event("process_step", {"step": result_step})
class ChatService:
"""Chat service with tool support"""
@ -129,12 +280,6 @@ class ChatService:
# 直接使用 provider 的 max_tokens
max_tokens = provider_max_tokens
# State tracking
all_steps = []
all_tool_calls = []
all_tool_results = []
step_index = 0
# Token usage tracking
total_usage = {
"prompt_tokens": 0,
@ -142,23 +287,12 @@ class ChatService:
"total_tokens": 0
}
# Global step IDs for thinking and text (persist across iterations)
thinking_step_id = None
thinking_step_idx = None
text_step_id = None
text_step_idx = None
# Streaming context for state management
ctx = StreamContext()
for iteration in range(MAX_ITERATIONS):
# Stream from LLM
full_content = ""
full_thinking = ""
tool_calls_list = []
# Step tracking - use unified step-{index} format
thinking_step_id = None
thinking_step_idx = None
text_step_id = None
text_step_idx = None
# Reset streaming context for this iteration
ctx.reset_iteration()
async for sse_line in llm.stream_call(
model=model,
@ -218,19 +352,16 @@ class ChatService:
if chunk.get("content") or chunk.get("message"):
content = chunk.get("content") or chunk.get("message", {}).get("content", "")
if content:
# BUG FIX: Update full_content so it gets saved to database
prev_content_len = len(full_content)
full_content += content
if prev_content_len == 0: # New text stream started
text_step_idx = step_index
text_step_id = f"step-{step_index}"
step_index += 1
prev_len = len(ctx.full_content)
ctx.full_content += content
if prev_len == 0: # New text stream started
ctx.start_stream_step("text")
yield _sse_event("process_step", {
"step": {
"id": text_step_id if prev_content_len == 0 else f"step-{step_index - 1}",
"index": text_step_idx if prev_content_len == 0 else step_index - 1,
"id": ctx.current_step_id if prev_len == 0 else f"step-{ctx.step_index - 1}",
"index": ctx.current_step_idx if prev_len == 0 else ctx.step_index - 1,
"type": "text",
"content": full_content # Always send accumulated content
"content": ctx.full_content
}
})
continue
@ -238,96 +369,43 @@ class ChatService:
delta = choices[0].get("delta", {})
# Handle reasoning (thinking)
reasoning = delta.get("reasoning_content", "")
if reasoning:
prev_thinking_len = len(full_thinking)
full_thinking += reasoning
if prev_thinking_len == 0: # New thinking stream started
thinking_step_idx = step_index
thinking_step_id = f"step-{step_index}"
step_index += 1
yield _sse_event("process_step", {
"step": {
"id": thinking_step_id,
"index": thinking_step_idx,
"type": "thinking",
"content": full_thinking
}
})
yield_obj = ctx.handle_thinking_stream(delta)
if yield_obj:
yield yield_obj
# Handle content
content = delta.get("content", "")
if content:
prev_content_len = len(full_content)
full_content += content
if prev_content_len == 0: # New text stream started
text_step_idx = step_index
text_step_id = f"step-{step_index}"
step_index += 1
yield _sse_event("process_step", {
"step": {
"id": text_step_id,
"index": text_step_idx,
"type": "text",
"content": full_content
}
})
yield_obj = ctx.handle_text_stream(delta)
if yield_obj:
yield yield_obj
# Accumulate tool calls
tool_calls_delta = delta.get("tool_calls", [])
for tc in tool_calls_delta:
idx = tc.get("index", 0)
if idx >= len(tool_calls_list):
tool_calls_list.append({
if idx >= len(ctx.tool_calls_list):
ctx.tool_calls_list.append({
"id": tc.get("id", ""),
"type": "function",
"function": {"name": "", "arguments": ""}
})
func = tc.get("function", {})
if func.get("name"):
tool_calls_list[idx]["function"]["name"] += func["name"]
ctx.tool_calls_list[idx]["function"]["name"] += func["name"]
if func.get("arguments"):
tool_calls_list[idx]["function"]["arguments"] += func["arguments"]
ctx.tool_calls_list[idx]["function"]["arguments"] += func["arguments"]
# Save thinking step
if thinking_step_id is not None:
all_steps.append({
"id": thinking_step_id,
"index": thinking_step_idx,
"type": "thinking",
"content": full_thinking
})
# Save text step
if text_step_id is not None:
all_steps.append({
"id": text_step_id,
"index": text_step_idx,
"type": "text",
"content": full_content
})
# Save streaming step (thinking or text)
ctx.save_streaming_step()
# Handle tool calls
if tool_calls_list:
all_tool_calls.extend(tool_calls_list)
if ctx.tool_calls_list:
ctx.all_tool_calls.extend(ctx.tool_calls_list)
# Yield tool_call steps - use unified step-{index} format
tool_call_step_ids = [] # Track step IDs for tool calls
for tc in tool_calls_list:
call_step_idx = step_index
call_step_id = f"step-{step_index}"
tool_call_step_ids.append(call_step_id)
step_index += 1
call_step = {
"id": call_step_id,
"index": call_step_idx,
"type": "tool_call",
"id_ref": tc.get("id", ""),
"name": tc["function"]["name"],
"arguments": tc["function"]["arguments"]
}
all_steps.append(call_step)
yield _sse_event("process_step", {"step": call_step})
# Handle tool_call steps
tool_call_step_ids, tool_call_steps, yield_objs = ctx.handle_tool_call()
ctx.all_steps.extend(tool_call_steps)
for yield_obj in yield_objs:
yield yield_obj
# Execute tools
tool_context = {
@ -337,39 +415,17 @@ class ChatService:
"user_permission_level": user_permission_level
}
tool_results = self.tool_executor.process_tool_calls_parallel(
tool_calls_list, tool_context
ctx.tool_calls_list, tool_context
)
# Yield tool_result steps - use unified step-{index} format
# Handle tool_result steps
for i, tr in enumerate(tool_results):
tool_call_step_id = tool_call_step_ids[i] if i < len(tool_call_step_ids) else f"step-{i}"
result_step_idx = step_index
result_step_id = f"step-{step_index}"
step_index += 1
result_step, yield_obj = ctx.handle_tool_result(tr, tool_call_step_id)
ctx.all_steps.append(result_step)
yield yield_obj
# 解析 content 中的 success 状态
content = tr.get("content", "")
success = True
try:
content_obj = json.loads(content)
if isinstance(content_obj, dict):
success = content_obj.get("success", True)
except:
pass
result_step = {
"id": result_step_id,
"index": result_step_idx,
"type": "tool_result",
"id_ref": tool_call_step_id, # Reference to the tool_call step
"name": tr.get("name", ""),
"content": content,
"success": success
}
all_steps.append(result_step)
yield _sse_event("process_step", {"step": result_step})
all_tool_results.append({
ctx.all_tool_results.append({
"role": "tool",
"tool_call_id": tr.get("tool_call_id", ""),
"content": tr.get("content", "")
@ -378,27 +434,27 @@ class ChatService:
# Add assistant message with tool calls for next iteration
messages.append({
"role": "assistant",
"content": full_content or "",
"tool_calls": tool_calls_list
"content": ctx.full_content or "",
"tool_calls": ctx.tool_calls_list
})
messages.extend(all_tool_results[-len(tool_results):])
all_tool_results = []
messages.extend(ctx.all_tool_results[-len(tool_results):])
ctx.all_tool_results = []
continue
# No tool calls - final iteration, save message
msg_id = str(uuid.uuid4())
# 使用 API 返回的真实 completion_tokens如果 API 没返回则降级使用估算值
actual_token_count = total_usage.get("completion_tokens", 0) or len(full_content) // 4
actual_token_count = total_usage.get("completion_tokens", 0) or len(ctx.full_content) // 4
logger.info(f"[TOKEN] total_usage: {total_usage}, actual_token_count: {actual_token_count}")
self._save_message(
conversation.id,
msg_id,
full_content,
all_tool_calls,
all_tool_results,
all_steps,
ctx.full_content,
ctx.all_tool_calls,
ctx.all_tool_results,
ctx.all_steps,
actual_token_count,
total_usage
)
@ -411,15 +467,15 @@ class ChatService:
return
# Max iterations exceeded - save message before error
if full_content or all_tool_calls:
if ctx.full_content or ctx.all_tool_calls:
msg_id = str(uuid.uuid4())
self._save_message(
conversation.id,
msg_id,
full_content,
all_tool_calls,
all_tool_results,
all_steps,
ctx.full_content,
ctx.all_tool_calls,
ctx.all_tool_results,
ctx.all_steps,
actual_token_count,
total_usage
)

288
luxx/services/task.py Normal file
View File

@ -0,0 +1,288 @@
"""Task module for autonomous task execution"""
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
import logging
from typing import List, Optional, Dict, Any
from luxx.utils.helpers import generate_id
logger = logging.getLogger(__name__)
class TaskStatus(Enum):
"""Task status enum"""
PENDING = "pending"
READY = "ready"
RUNNING = "running"
BLOCK = "block"
TERMINATED = "terminated"
class StepStatus(Enum):
"""Step status enum"""
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
SKIPPED = "skipped"
@dataclass
class Step:
"""Task step"""
id: str
name: str
description: str = ""
depends_on: List[str] = field(default_factory=list)
status: StepStatus = StepStatus.PENDING
result: Optional[Dict[str, Any]] = None
created_at: datetime = field(default_factory=datetime.now)
updated_at: datetime = field(default_factory=datetime.now)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary"""
return {
"id": self.id,
"name": self.name,
"description": self.description,
"depends_on": self.depends_on,
"status": self.status.value,
"result": self.result,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None
}
@dataclass
class Task:
"""Task entity"""
id: str
name: str
description: str = ""
goal: str = ""
status: TaskStatus = TaskStatus.PENDING
steps: List[Step] = field(default_factory=list)
subtasks: List["Task"] = field(default_factory=list)
result: Optional[Dict[str, Any]] = None
created_at: datetime = field(default_factory=datetime.now)
updated_at: datetime = field(default_factory=datetime.now)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary"""
return {
"id": self.id,
"name": self.name,
"description": self.description,
"goal": self.goal,
"status": self.status.value,
"steps": [s.to_dict() for s in self.steps],
"subtasks": [t.to_dict() for t in self.subtasks],
"result": self.result,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None
}
class TaskGraph:
"""Task graph for managing step dependencies"""
def __init__(self, task: Task):
self.task = task
self._adjacency: Dict[str, List[str]] = {}
self._reverse_adjacency: Dict[str, List[str]] = {}
self._in_degree: Dict[str, int] = {}
self._build_graph()
def _build_graph(self) -> None:
"""Build graph from task steps"""
for step in self.task.steps:
self._adjacency[step.id] = []
self._reverse_adjacency[step.id] = []
self._in_degree[step.id] = 0
for step in self.task.steps:
for dep_id in step.depends_on:
if dep_id in self._adjacency:
self._adjacency[dep_id].append(step.id)
self._reverse_adjacency[step.id].append(dep_id)
self._in_degree[step.id] += 1
def topological_sort(self) -> List[Step]:
"""Get steps in topological order"""
in_degree = self._in_degree.copy()
queue = [step_id for step_id, degree in in_degree.items() if degree == 0]
result = []
step_map = {step.id: step for step in self.task.steps}
while queue:
queue.sort()
current = queue.pop(0)
result.append(step_map[current])
for dependent_id in self._adjacency[current]:
in_degree[dependent_id] -= 1
if in_degree[dependent_id] == 0:
queue.append(dependent_id)
return result
def get_ready_steps(self, completed_step_ids: List[str]) -> List[Step]:
"""Get steps that are ready to execute"""
step_map = {step.id: step for step in self.task.steps}
ready = []
for step in self.task.steps:
if step.id in completed_step_ids:
continue
if step.status != StepStatus.PENDING:
continue
deps_completed = all(dep_id in completed_step_ids for dep_id in step.depends_on)
if deps_completed:
ready.append(step)
return ready
def detect_cycles(self) -> List[List[str]]:
"""Detect cycles in the graph"""
WHITE, GRAY, BLACK = 0, 1, 2
color = {step.id: WHITE for step in self.task.steps}
cycles = []
def dfs(node: str, path: List[str]) -> bool:
color[node] = GRAY
path.append(node)
for neighbor in self._adjacency.get(node, []):
if color[neighbor] == GRAY:
cycle_start = path.index(neighbor)
cycles.append(path[cycle_start:] + [neighbor])
return True
elif color[neighbor] == WHITE:
if dfs(neighbor, path):
return True
path.pop()
color[node] = BLACK
return False
for step in self.task.steps:
if color[step.id] == WHITE:
dfs(step.id, [])
return cycles
def validate(self) -> tuple[bool, Optional[str]]:
"""Validate the graph structure"""
cycles = self.detect_cycles()
if cycles:
return False, f"Circular dependency detected: {cycles[0]}"
step_ids = {step.id for step in self.task.steps}
for step in self.task.steps:
for dep_id in step.depends_on:
if dep_id not in step_ids:
return False, f"Step '{step.name}' depends on non-existent step '{dep_id}'"
return True, None
class TaskService:
"""Task service for managing tasks"""
def __init__(self):
self._tasks: Dict[str, Task] = {}
def create_task(
self,
name: str,
goal: str,
description: str = "",
steps: List[Dict[str, Any]] = None
) -> Task:
"""Create a new task"""
task_id = generate_id("task")
task = Task(
id=task_id,
name=name,
description=description,
goal=goal
)
if steps:
for step_data in steps:
step = Step(
id=generate_id("step"),
name=step_data.get("name", ""),
description=step_data.get("description", "")
)
task.steps.append(step)
self._tasks[task_id] = task
logger.info(f"Created task: {task_id}")
return task
def get_task(self, task_id: str) -> Optional[Task]:
"""Get task by ID"""
return self._tasks.get(task_id)
def list_tasks(self) -> List[Task]:
"""List all tasks"""
return list(self._tasks.values())
def update_task_status(
self,
task_id: str,
status: TaskStatus,
result: Any = None
) -> Optional[Task]:
"""Update task status"""
task = self._tasks.get(task_id)
if not task:
return None
task.status = status
task.result = result
task.updated_at = datetime.now()
return task
def add_steps(
self,
task_id: str,
steps: List[Dict[str, Any]]
) -> Optional[List[Step]]:
"""Add steps to task"""
task = self._tasks.get(task_id)
if not task:
return None
result = []
for step_data in steps:
step = Step(
id=generate_id("step"),
name=step_data.get("name", ""),
description=step_data.get("description", ""),
depends_on=step_data.get("depends_on", [])
)
task.steps.append(step)
result.append(step)
task.updated_at = datetime.now()
return result
def delete_task(self, task_id: str) -> bool:
"""Delete task"""
if task_id not in self._tasks:
return False
del self._tasks[task_id]
return True
def build_graph(self, task_id: str) -> Optional[TaskGraph]:
"""Build task graph for a task"""
task = self._tasks.get(task_id)
if not task:
return None
return TaskGraph(task)
task_service = TaskService()