Compare commits
2 Commits
4de03866f4
...
dc08267c15
| Author | SHA1 | Date |
|---|---|---|
|
|
dc08267c15 | |
|
|
f10909bec3 |
|
|
@ -45,6 +45,157 @@ def get_llm_client(conversation: Conversation = None):
|
||||||
return client, max_tokens
|
return client, max_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class StreamContext:
|
||||||
|
"""Context for streaming response state management."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
step_index: int = 0,
|
||||||
|
current_step_id: str = None,
|
||||||
|
current_step_idx: int = None,
|
||||||
|
current_stream_type: str = None,
|
||||||
|
full_content: str = "",
|
||||||
|
full_thinking: str = ""
|
||||||
|
):
|
||||||
|
self.step_index = step_index
|
||||||
|
self.current_step_id = current_step_id
|
||||||
|
self.current_step_idx = current_step_idx
|
||||||
|
self.current_stream_type = current_stream_type
|
||||||
|
self.full_content = full_content
|
||||||
|
self.full_thinking = full_thinking
|
||||||
|
self.all_steps = []
|
||||||
|
self.all_tool_calls = []
|
||||||
|
self.all_tool_results = []
|
||||||
|
self.tool_calls_list = []
|
||||||
|
|
||||||
|
def reset_iteration(self):
|
||||||
|
"""Reset streaming step tracker for new iteration."""
|
||||||
|
self.current_step_id = None
|
||||||
|
self.current_step_idx = None
|
||||||
|
self.current_stream_type = None
|
||||||
|
self.full_content = ""
|
||||||
|
self.full_thinking = ""
|
||||||
|
self.tool_calls_list = []
|
||||||
|
|
||||||
|
def start_stream_step(self, step_type: str) -> str:
|
||||||
|
"""Start a new streaming step. Returns the step_id."""
|
||||||
|
self.current_step_idx = self.step_index
|
||||||
|
self.current_step_id = f"step-{self.step_index}"
|
||||||
|
self.current_stream_type = step_type
|
||||||
|
self.step_index += 1
|
||||||
|
return self.current_step_id
|
||||||
|
|
||||||
|
def yield_stream_step(self, step_type: str, content: str) -> Dict[str, Any]:
|
||||||
|
"""Yield a streaming step event."""
|
||||||
|
return _sse_event("process_step", {
|
||||||
|
"step": {
|
||||||
|
"id": self.current_step_id,
|
||||||
|
"index": self.current_step_idx,
|
||||||
|
"type": step_type,
|
||||||
|
"content": content
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
def save_streaming_step(self):
|
||||||
|
"""Save the current streaming step to all_steps."""
|
||||||
|
if self.current_step_id is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.current_stream_type == "thinking":
|
||||||
|
self.all_steps.append({
|
||||||
|
"id": self.current_step_id,
|
||||||
|
"index": self.current_step_idx,
|
||||||
|
"type": "thinking",
|
||||||
|
"content": self.full_thinking
|
||||||
|
})
|
||||||
|
elif self.current_stream_type == "text":
|
||||||
|
self.all_steps.append({
|
||||||
|
"id": self.current_step_id,
|
||||||
|
"index": self.current_step_idx,
|
||||||
|
"type": "text",
|
||||||
|
"content": self.full_content
|
||||||
|
})
|
||||||
|
|
||||||
|
def handle_thinking_stream(self, delta: Dict) -> Optional[Dict]:
|
||||||
|
"""Handle reasoning/thinking delta. Returns yield_obj if step was yielded."""
|
||||||
|
reasoning = delta.get("reasoning_content", "")
|
||||||
|
if not reasoning:
|
||||||
|
return None
|
||||||
|
|
||||||
|
prev_len = len(self.full_thinking)
|
||||||
|
self.full_thinking += reasoning
|
||||||
|
|
||||||
|
if prev_len == 0: # New thinking stream started
|
||||||
|
self.start_stream_step("thinking")
|
||||||
|
|
||||||
|
return self.yield_stream_step("thinking", self.full_thinking)
|
||||||
|
|
||||||
|
def handle_text_stream(self, delta: Dict) -> Optional[Dict]:
|
||||||
|
"""Handle content delta. Returns yield_obj if step was yielded."""
|
||||||
|
content = delta.get("content", "")
|
||||||
|
if not content:
|
||||||
|
return None
|
||||||
|
|
||||||
|
prev_len = len(self.full_content)
|
||||||
|
self.full_content += content
|
||||||
|
|
||||||
|
if prev_len == 0: # New text stream started
|
||||||
|
self.start_stream_step("text")
|
||||||
|
|
||||||
|
return self.yield_stream_step("text", self.full_content)
|
||||||
|
|
||||||
|
def handle_tool_call(self) -> tuple:
|
||||||
|
"""Handle tool calls. Returns (tool_call_step_ids, tool_call_steps, yield_objs)."""
|
||||||
|
tool_call_step_ids = []
|
||||||
|
tool_call_steps = []
|
||||||
|
yield_objs = []
|
||||||
|
|
||||||
|
for tc in self.tool_calls_list:
|
||||||
|
call_step_idx = self.step_index
|
||||||
|
call_step_id = f"step-{self.step_index}"
|
||||||
|
tool_call_step_ids.append(call_step_id)
|
||||||
|
self.step_index += 1
|
||||||
|
|
||||||
|
call_step = {
|
||||||
|
"id": call_step_id,
|
||||||
|
"index": call_step_idx,
|
||||||
|
"type": "tool_call",
|
||||||
|
"id_ref": tc.get("id", ""),
|
||||||
|
"name": tc["function"]["name"],
|
||||||
|
"arguments": tc["function"]["arguments"]
|
||||||
|
}
|
||||||
|
tool_call_steps.append(call_step)
|
||||||
|
yield_objs.append(_sse_event("process_step", {"step": call_step}))
|
||||||
|
|
||||||
|
return tool_call_step_ids, tool_call_steps, yield_objs
|
||||||
|
|
||||||
|
def handle_tool_result(self, tool_result: Dict, tool_call_step_id: str) -> tuple:
|
||||||
|
"""Handle single tool result. Returns (result_step, yield_obj)."""
|
||||||
|
result_step_idx = self.step_index
|
||||||
|
result_step_id = f"step-{self.step_index}"
|
||||||
|
self.step_index += 1
|
||||||
|
|
||||||
|
content = tool_result.get("content", "")
|
||||||
|
success = True
|
||||||
|
try:
|
||||||
|
content_obj = json.loads(content)
|
||||||
|
if isinstance(content_obj, dict):
|
||||||
|
success = content_obj.get("success", True)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
result_step = {
|
||||||
|
"id": result_step_id,
|
||||||
|
"index": result_step_idx,
|
||||||
|
"type": "tool_result",
|
||||||
|
"id_ref": tool_call_step_id,
|
||||||
|
"name": tool_result.get("name", ""),
|
||||||
|
"content": content,
|
||||||
|
"success": success
|
||||||
|
}
|
||||||
|
return result_step, _sse_event("process_step", {"step": result_step})
|
||||||
|
|
||||||
|
|
||||||
class ChatService:
|
class ChatService:
|
||||||
"""Chat service with tool support"""
|
"""Chat service with tool support"""
|
||||||
|
|
||||||
|
|
@ -129,12 +280,6 @@ class ChatService:
|
||||||
# 直接使用 provider 的 max_tokens
|
# 直接使用 provider 的 max_tokens
|
||||||
max_tokens = provider_max_tokens
|
max_tokens = provider_max_tokens
|
||||||
|
|
||||||
# State tracking
|
|
||||||
all_steps = []
|
|
||||||
all_tool_calls = []
|
|
||||||
all_tool_results = []
|
|
||||||
step_index = 0
|
|
||||||
|
|
||||||
# Token usage tracking
|
# Token usage tracking
|
||||||
total_usage = {
|
total_usage = {
|
||||||
"prompt_tokens": 0,
|
"prompt_tokens": 0,
|
||||||
|
|
@ -142,23 +287,12 @@ class ChatService:
|
||||||
"total_tokens": 0
|
"total_tokens": 0
|
||||||
}
|
}
|
||||||
|
|
||||||
# Global step IDs for thinking and text (persist across iterations)
|
# Streaming context for state management
|
||||||
thinking_step_id = None
|
ctx = StreamContext()
|
||||||
thinking_step_idx = None
|
|
||||||
text_step_id = None
|
|
||||||
text_step_idx = None
|
|
||||||
|
|
||||||
for iteration in range(MAX_ITERATIONS):
|
for iteration in range(MAX_ITERATIONS):
|
||||||
# Stream from LLM
|
# Reset streaming context for this iteration
|
||||||
full_content = ""
|
ctx.reset_iteration()
|
||||||
full_thinking = ""
|
|
||||||
tool_calls_list = []
|
|
||||||
|
|
||||||
# Step tracking - use unified step-{index} format
|
|
||||||
thinking_step_id = None
|
|
||||||
thinking_step_idx = None
|
|
||||||
text_step_id = None
|
|
||||||
text_step_idx = None
|
|
||||||
|
|
||||||
async for sse_line in llm.stream_call(
|
async for sse_line in llm.stream_call(
|
||||||
model=model,
|
model=model,
|
||||||
|
|
@ -218,19 +352,16 @@ class ChatService:
|
||||||
if chunk.get("content") or chunk.get("message"):
|
if chunk.get("content") or chunk.get("message"):
|
||||||
content = chunk.get("content") or chunk.get("message", {}).get("content", "")
|
content = chunk.get("content") or chunk.get("message", {}).get("content", "")
|
||||||
if content:
|
if content:
|
||||||
# BUG FIX: Update full_content so it gets saved to database
|
prev_len = len(ctx.full_content)
|
||||||
prev_content_len = len(full_content)
|
ctx.full_content += content
|
||||||
full_content += content
|
if prev_len == 0: # New text stream started
|
||||||
if prev_content_len == 0: # New text stream started
|
ctx.start_stream_step("text")
|
||||||
text_step_idx = step_index
|
|
||||||
text_step_id = f"step-{step_index}"
|
|
||||||
step_index += 1
|
|
||||||
yield _sse_event("process_step", {
|
yield _sse_event("process_step", {
|
||||||
"step": {
|
"step": {
|
||||||
"id": text_step_id if prev_content_len == 0 else f"step-{step_index - 1}",
|
"id": ctx.current_step_id if prev_len == 0 else f"step-{ctx.step_index - 1}",
|
||||||
"index": text_step_idx if prev_content_len == 0 else step_index - 1,
|
"index": ctx.current_step_idx if prev_len == 0 else ctx.step_index - 1,
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"content": full_content # Always send accumulated content
|
"content": ctx.full_content
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
continue
|
continue
|
||||||
|
|
@ -238,96 +369,43 @@ class ChatService:
|
||||||
delta = choices[0].get("delta", {})
|
delta = choices[0].get("delta", {})
|
||||||
|
|
||||||
# Handle reasoning (thinking)
|
# Handle reasoning (thinking)
|
||||||
reasoning = delta.get("reasoning_content", "")
|
yield_obj = ctx.handle_thinking_stream(delta)
|
||||||
if reasoning:
|
if yield_obj:
|
||||||
prev_thinking_len = len(full_thinking)
|
yield yield_obj
|
||||||
full_thinking += reasoning
|
|
||||||
if prev_thinking_len == 0: # New thinking stream started
|
|
||||||
thinking_step_idx = step_index
|
|
||||||
thinking_step_id = f"step-{step_index}"
|
|
||||||
step_index += 1
|
|
||||||
yield _sse_event("process_step", {
|
|
||||||
"step": {
|
|
||||||
"id": thinking_step_id,
|
|
||||||
"index": thinking_step_idx,
|
|
||||||
"type": "thinking",
|
|
||||||
"content": full_thinking
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
# Handle content
|
# Handle content
|
||||||
content = delta.get("content", "")
|
yield_obj = ctx.handle_text_stream(delta)
|
||||||
if content:
|
if yield_obj:
|
||||||
prev_content_len = len(full_content)
|
yield yield_obj
|
||||||
full_content += content
|
|
||||||
if prev_content_len == 0: # New text stream started
|
|
||||||
text_step_idx = step_index
|
|
||||||
text_step_id = f"step-{step_index}"
|
|
||||||
step_index += 1
|
|
||||||
yield _sse_event("process_step", {
|
|
||||||
"step": {
|
|
||||||
"id": text_step_id,
|
|
||||||
"index": text_step_idx,
|
|
||||||
"type": "text",
|
|
||||||
"content": full_content
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
# Accumulate tool calls
|
# Accumulate tool calls
|
||||||
tool_calls_delta = delta.get("tool_calls", [])
|
tool_calls_delta = delta.get("tool_calls", [])
|
||||||
for tc in tool_calls_delta:
|
for tc in tool_calls_delta:
|
||||||
idx = tc.get("index", 0)
|
idx = tc.get("index", 0)
|
||||||
if idx >= len(tool_calls_list):
|
if idx >= len(ctx.tool_calls_list):
|
||||||
tool_calls_list.append({
|
ctx.tool_calls_list.append({
|
||||||
"id": tc.get("id", ""),
|
"id": tc.get("id", ""),
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {"name": "", "arguments": ""}
|
"function": {"name": "", "arguments": ""}
|
||||||
})
|
})
|
||||||
func = tc.get("function", {})
|
func = tc.get("function", {})
|
||||||
if func.get("name"):
|
if func.get("name"):
|
||||||
tool_calls_list[idx]["function"]["name"] += func["name"]
|
ctx.tool_calls_list[idx]["function"]["name"] += func["name"]
|
||||||
if func.get("arguments"):
|
if func.get("arguments"):
|
||||||
tool_calls_list[idx]["function"]["arguments"] += func["arguments"]
|
ctx.tool_calls_list[idx]["function"]["arguments"] += func["arguments"]
|
||||||
|
|
||||||
# Save thinking step
|
# Save streaming step (thinking or text)
|
||||||
if thinking_step_id is not None:
|
ctx.save_streaming_step()
|
||||||
all_steps.append({
|
|
||||||
"id": thinking_step_id,
|
|
||||||
"index": thinking_step_idx,
|
|
||||||
"type": "thinking",
|
|
||||||
"content": full_thinking
|
|
||||||
})
|
|
||||||
|
|
||||||
# Save text step
|
|
||||||
if text_step_id is not None:
|
|
||||||
all_steps.append({
|
|
||||||
"id": text_step_id,
|
|
||||||
"index": text_step_idx,
|
|
||||||
"type": "text",
|
|
||||||
"content": full_content
|
|
||||||
})
|
|
||||||
|
|
||||||
# Handle tool calls
|
# Handle tool calls
|
||||||
if tool_calls_list:
|
if ctx.tool_calls_list:
|
||||||
all_tool_calls.extend(tool_calls_list)
|
ctx.all_tool_calls.extend(ctx.tool_calls_list)
|
||||||
|
|
||||||
# Yield tool_call steps - use unified step-{index} format
|
# Handle tool_call steps
|
||||||
tool_call_step_ids = [] # Track step IDs for tool calls
|
tool_call_step_ids, tool_call_steps, yield_objs = ctx.handle_tool_call()
|
||||||
for tc in tool_calls_list:
|
ctx.all_steps.extend(tool_call_steps)
|
||||||
call_step_idx = step_index
|
for yield_obj in yield_objs:
|
||||||
call_step_id = f"step-{step_index}"
|
yield yield_obj
|
||||||
tool_call_step_ids.append(call_step_id)
|
|
||||||
step_index += 1
|
|
||||||
call_step = {
|
|
||||||
"id": call_step_id,
|
|
||||||
"index": call_step_idx,
|
|
||||||
"type": "tool_call",
|
|
||||||
"id_ref": tc.get("id", ""),
|
|
||||||
"name": tc["function"]["name"],
|
|
||||||
"arguments": tc["function"]["arguments"]
|
|
||||||
}
|
|
||||||
all_steps.append(call_step)
|
|
||||||
yield _sse_event("process_step", {"step": call_step})
|
|
||||||
|
|
||||||
# Execute tools
|
# Execute tools
|
||||||
tool_context = {
|
tool_context = {
|
||||||
|
|
@ -337,39 +415,17 @@ class ChatService:
|
||||||
"user_permission_level": user_permission_level
|
"user_permission_level": user_permission_level
|
||||||
}
|
}
|
||||||
tool_results = self.tool_executor.process_tool_calls_parallel(
|
tool_results = self.tool_executor.process_tool_calls_parallel(
|
||||||
tool_calls_list, tool_context
|
ctx.tool_calls_list, tool_context
|
||||||
)
|
)
|
||||||
|
|
||||||
# Yield tool_result steps - use unified step-{index} format
|
# Handle tool_result steps
|
||||||
for i, tr in enumerate(tool_results):
|
for i, tr in enumerate(tool_results):
|
||||||
tool_call_step_id = tool_call_step_ids[i] if i < len(tool_call_step_ids) else f"step-{i}"
|
tool_call_step_id = tool_call_step_ids[i] if i < len(tool_call_step_ids) else f"step-{i}"
|
||||||
result_step_idx = step_index
|
result_step, yield_obj = ctx.handle_tool_result(tr, tool_call_step_id)
|
||||||
result_step_id = f"step-{step_index}"
|
ctx.all_steps.append(result_step)
|
||||||
step_index += 1
|
yield yield_obj
|
||||||
|
|
||||||
# 解析 content 中的 success 状态
|
ctx.all_tool_results.append({
|
||||||
content = tr.get("content", "")
|
|
||||||
success = True
|
|
||||||
try:
|
|
||||||
content_obj = json.loads(content)
|
|
||||||
if isinstance(content_obj, dict):
|
|
||||||
success = content_obj.get("success", True)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
result_step = {
|
|
||||||
"id": result_step_id,
|
|
||||||
"index": result_step_idx,
|
|
||||||
"type": "tool_result",
|
|
||||||
"id_ref": tool_call_step_id, # Reference to the tool_call step
|
|
||||||
"name": tr.get("name", ""),
|
|
||||||
"content": content,
|
|
||||||
"success": success
|
|
||||||
}
|
|
||||||
all_steps.append(result_step)
|
|
||||||
yield _sse_event("process_step", {"step": result_step})
|
|
||||||
|
|
||||||
all_tool_results.append({
|
|
||||||
"role": "tool",
|
"role": "tool",
|
||||||
"tool_call_id": tr.get("tool_call_id", ""),
|
"tool_call_id": tr.get("tool_call_id", ""),
|
||||||
"content": tr.get("content", "")
|
"content": tr.get("content", "")
|
||||||
|
|
@ -378,27 +434,27 @@ class ChatService:
|
||||||
# Add assistant message with tool calls for next iteration
|
# Add assistant message with tool calls for next iteration
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": full_content or "",
|
"content": ctx.full_content or "",
|
||||||
"tool_calls": tool_calls_list
|
"tool_calls": ctx.tool_calls_list
|
||||||
})
|
})
|
||||||
messages.extend(all_tool_results[-len(tool_results):])
|
messages.extend(ctx.all_tool_results[-len(tool_results):])
|
||||||
all_tool_results = []
|
ctx.all_tool_results = []
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# No tool calls - final iteration, save message
|
# No tool calls - final iteration, save message
|
||||||
msg_id = str(uuid.uuid4())
|
msg_id = str(uuid.uuid4())
|
||||||
|
|
||||||
# 使用 API 返回的真实 completion_tokens,如果 API 没返回则降级使用估算值
|
# 使用 API 返回的真实 completion_tokens,如果 API 没返回则降级使用估算值
|
||||||
actual_token_count = total_usage.get("completion_tokens", 0) or len(full_content) // 4
|
actual_token_count = total_usage.get("completion_tokens", 0) or len(ctx.full_content) // 4
|
||||||
logger.info(f"[TOKEN] total_usage: {total_usage}, actual_token_count: {actual_token_count}")
|
logger.info(f"[TOKEN] total_usage: {total_usage}, actual_token_count: {actual_token_count}")
|
||||||
|
|
||||||
self._save_message(
|
self._save_message(
|
||||||
conversation.id,
|
conversation.id,
|
||||||
msg_id,
|
msg_id,
|
||||||
full_content,
|
ctx.full_content,
|
||||||
all_tool_calls,
|
ctx.all_tool_calls,
|
||||||
all_tool_results,
|
ctx.all_tool_results,
|
||||||
all_steps,
|
ctx.all_steps,
|
||||||
actual_token_count,
|
actual_token_count,
|
||||||
total_usage
|
total_usage
|
||||||
)
|
)
|
||||||
|
|
@ -411,15 +467,15 @@ class ChatService:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Max iterations exceeded - save message before error
|
# Max iterations exceeded - save message before error
|
||||||
if full_content or all_tool_calls:
|
if ctx.full_content or ctx.all_tool_calls:
|
||||||
msg_id = str(uuid.uuid4())
|
msg_id = str(uuid.uuid4())
|
||||||
self._save_message(
|
self._save_message(
|
||||||
conversation.id,
|
conversation.id,
|
||||||
msg_id,
|
msg_id,
|
||||||
full_content,
|
ctx.full_content,
|
||||||
all_tool_calls,
|
ctx.all_tool_calls,
|
||||||
all_tool_results,
|
ctx.all_tool_results,
|
||||||
all_steps,
|
ctx.all_steps,
|
||||||
actual_token_count,
|
actual_token_count,
|
||||||
total_usage
|
total_usage
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,288 @@
|
||||||
|
"""Task module for autonomous task execution"""
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
import logging
|
||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
from luxx.utils.helpers import generate_id
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TaskStatus(Enum):
|
||||||
|
"""Task status enum"""
|
||||||
|
PENDING = "pending"
|
||||||
|
READY = "ready"
|
||||||
|
RUNNING = "running"
|
||||||
|
BLOCK = "block"
|
||||||
|
TERMINATED = "terminated"
|
||||||
|
|
||||||
|
|
||||||
|
class StepStatus(Enum):
|
||||||
|
"""Step status enum"""
|
||||||
|
PENDING = "pending"
|
||||||
|
RUNNING = "running"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
FAILED = "failed"
|
||||||
|
SKIPPED = "skipped"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Step:
|
||||||
|
"""Task step"""
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
description: str = ""
|
||||||
|
depends_on: List[str] = field(default_factory=list)
|
||||||
|
status: StepStatus = StepStatus.PENDING
|
||||||
|
result: Optional[Dict[str, Any]] = None
|
||||||
|
created_at: datetime = field(default_factory=datetime.now)
|
||||||
|
updated_at: datetime = field(default_factory=datetime.now)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary"""
|
||||||
|
return {
|
||||||
|
"id": self.id,
|
||||||
|
"name": self.name,
|
||||||
|
"description": self.description,
|
||||||
|
"depends_on": self.depends_on,
|
||||||
|
"status": self.status.value,
|
||||||
|
"result": self.result,
|
||||||
|
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||||
|
"updated_at": self.updated_at.isoformat() if self.updated_at else None
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Task:
|
||||||
|
"""Task entity"""
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
description: str = ""
|
||||||
|
goal: str = ""
|
||||||
|
status: TaskStatus = TaskStatus.PENDING
|
||||||
|
steps: List[Step] = field(default_factory=list)
|
||||||
|
subtasks: List["Task"] = field(default_factory=list)
|
||||||
|
result: Optional[Dict[str, Any]] = None
|
||||||
|
created_at: datetime = field(default_factory=datetime.now)
|
||||||
|
updated_at: datetime = field(default_factory=datetime.now)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary"""
|
||||||
|
return {
|
||||||
|
"id": self.id,
|
||||||
|
"name": self.name,
|
||||||
|
"description": self.description,
|
||||||
|
"goal": self.goal,
|
||||||
|
"status": self.status.value,
|
||||||
|
"steps": [s.to_dict() for s in self.steps],
|
||||||
|
"subtasks": [t.to_dict() for t in self.subtasks],
|
||||||
|
"result": self.result,
|
||||||
|
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||||
|
"updated_at": self.updated_at.isoformat() if self.updated_at else None
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TaskGraph:
|
||||||
|
"""Task graph for managing step dependencies"""
|
||||||
|
|
||||||
|
def __init__(self, task: Task):
|
||||||
|
self.task = task
|
||||||
|
self._adjacency: Dict[str, List[str]] = {}
|
||||||
|
self._reverse_adjacency: Dict[str, List[str]] = {}
|
||||||
|
self._in_degree: Dict[str, int] = {}
|
||||||
|
self._build_graph()
|
||||||
|
|
||||||
|
def _build_graph(self) -> None:
|
||||||
|
"""Build graph from task steps"""
|
||||||
|
for step in self.task.steps:
|
||||||
|
self._adjacency[step.id] = []
|
||||||
|
self._reverse_adjacency[step.id] = []
|
||||||
|
self._in_degree[step.id] = 0
|
||||||
|
|
||||||
|
for step in self.task.steps:
|
||||||
|
for dep_id in step.depends_on:
|
||||||
|
if dep_id in self._adjacency:
|
||||||
|
self._adjacency[dep_id].append(step.id)
|
||||||
|
self._reverse_adjacency[step.id].append(dep_id)
|
||||||
|
self._in_degree[step.id] += 1
|
||||||
|
|
||||||
|
def topological_sort(self) -> List[Step]:
|
||||||
|
"""Get steps in topological order"""
|
||||||
|
in_degree = self._in_degree.copy()
|
||||||
|
queue = [step_id for step_id, degree in in_degree.items() if degree == 0]
|
||||||
|
result = []
|
||||||
|
step_map = {step.id: step for step in self.task.steps}
|
||||||
|
|
||||||
|
while queue:
|
||||||
|
queue.sort()
|
||||||
|
current = queue.pop(0)
|
||||||
|
result.append(step_map[current])
|
||||||
|
|
||||||
|
for dependent_id in self._adjacency[current]:
|
||||||
|
in_degree[dependent_id] -= 1
|
||||||
|
if in_degree[dependent_id] == 0:
|
||||||
|
queue.append(dependent_id)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_ready_steps(self, completed_step_ids: List[str]) -> List[Step]:
|
||||||
|
"""Get steps that are ready to execute"""
|
||||||
|
step_map = {step.id: step for step in self.task.steps}
|
||||||
|
ready = []
|
||||||
|
|
||||||
|
for step in self.task.steps:
|
||||||
|
if step.id in completed_step_ids:
|
||||||
|
continue
|
||||||
|
if step.status != StepStatus.PENDING:
|
||||||
|
continue
|
||||||
|
deps_completed = all(dep_id in completed_step_ids for dep_id in step.depends_on)
|
||||||
|
if deps_completed:
|
||||||
|
ready.append(step)
|
||||||
|
|
||||||
|
return ready
|
||||||
|
|
||||||
|
def detect_cycles(self) -> List[List[str]]:
|
||||||
|
"""Detect cycles in the graph"""
|
||||||
|
WHITE, GRAY, BLACK = 0, 1, 2
|
||||||
|
color = {step.id: WHITE for step in self.task.steps}
|
||||||
|
cycles = []
|
||||||
|
|
||||||
|
def dfs(node: str, path: List[str]) -> bool:
|
||||||
|
color[node] = GRAY
|
||||||
|
path.append(node)
|
||||||
|
|
||||||
|
for neighbor in self._adjacency.get(node, []):
|
||||||
|
if color[neighbor] == GRAY:
|
||||||
|
cycle_start = path.index(neighbor)
|
||||||
|
cycles.append(path[cycle_start:] + [neighbor])
|
||||||
|
return True
|
||||||
|
elif color[neighbor] == WHITE:
|
||||||
|
if dfs(neighbor, path):
|
||||||
|
return True
|
||||||
|
|
||||||
|
path.pop()
|
||||||
|
color[node] = BLACK
|
||||||
|
return False
|
||||||
|
|
||||||
|
for step in self.task.steps:
|
||||||
|
if color[step.id] == WHITE:
|
||||||
|
dfs(step.id, [])
|
||||||
|
|
||||||
|
return cycles
|
||||||
|
|
||||||
|
def validate(self) -> tuple[bool, Optional[str]]:
|
||||||
|
"""Validate the graph structure"""
|
||||||
|
cycles = self.detect_cycles()
|
||||||
|
if cycles:
|
||||||
|
return False, f"Circular dependency detected: {cycles[0]}"
|
||||||
|
|
||||||
|
step_ids = {step.id for step in self.task.steps}
|
||||||
|
for step in self.task.steps:
|
||||||
|
for dep_id in step.depends_on:
|
||||||
|
if dep_id not in step_ids:
|
||||||
|
return False, f"Step '{step.name}' depends on non-existent step '{dep_id}'"
|
||||||
|
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
|
||||||
|
class TaskService:
|
||||||
|
"""Task service for managing tasks"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._tasks: Dict[str, Task] = {}
|
||||||
|
|
||||||
|
def create_task(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
goal: str,
|
||||||
|
description: str = "",
|
||||||
|
steps: List[Dict[str, Any]] = None
|
||||||
|
) -> Task:
|
||||||
|
"""Create a new task"""
|
||||||
|
task_id = generate_id("task")
|
||||||
|
task = Task(
|
||||||
|
id=task_id,
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
goal=goal
|
||||||
|
)
|
||||||
|
|
||||||
|
if steps:
|
||||||
|
for step_data in steps:
|
||||||
|
step = Step(
|
||||||
|
id=generate_id("step"),
|
||||||
|
name=step_data.get("name", ""),
|
||||||
|
description=step_data.get("description", "")
|
||||||
|
)
|
||||||
|
task.steps.append(step)
|
||||||
|
|
||||||
|
self._tasks[task_id] = task
|
||||||
|
logger.info(f"Created task: {task_id}")
|
||||||
|
return task
|
||||||
|
|
||||||
|
def get_task(self, task_id: str) -> Optional[Task]:
|
||||||
|
"""Get task by ID"""
|
||||||
|
return self._tasks.get(task_id)
|
||||||
|
|
||||||
|
def list_tasks(self) -> List[Task]:
|
||||||
|
"""List all tasks"""
|
||||||
|
return list(self._tasks.values())
|
||||||
|
|
||||||
|
def update_task_status(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
status: TaskStatus,
|
||||||
|
result: Any = None
|
||||||
|
) -> Optional[Task]:
|
||||||
|
"""Update task status"""
|
||||||
|
task = self._tasks.get(task_id)
|
||||||
|
if not task:
|
||||||
|
return None
|
||||||
|
|
||||||
|
task.status = status
|
||||||
|
task.result = result
|
||||||
|
task.updated_at = datetime.now()
|
||||||
|
return task
|
||||||
|
|
||||||
|
def add_steps(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
steps: List[Dict[str, Any]]
|
||||||
|
) -> Optional[List[Step]]:
|
||||||
|
"""Add steps to task"""
|
||||||
|
task = self._tasks.get(task_id)
|
||||||
|
if not task:
|
||||||
|
return None
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for step_data in steps:
|
||||||
|
step = Step(
|
||||||
|
id=generate_id("step"),
|
||||||
|
name=step_data.get("name", ""),
|
||||||
|
description=step_data.get("description", ""),
|
||||||
|
depends_on=step_data.get("depends_on", [])
|
||||||
|
)
|
||||||
|
task.steps.append(step)
|
||||||
|
result.append(step)
|
||||||
|
|
||||||
|
task.updated_at = datetime.now()
|
||||||
|
return result
|
||||||
|
|
||||||
|
def delete_task(self, task_id: str) -> bool:
|
||||||
|
"""Delete task"""
|
||||||
|
if task_id not in self._tasks:
|
||||||
|
return False
|
||||||
|
|
||||||
|
del self._tasks[task_id]
|
||||||
|
return True
|
||||||
|
|
||||||
|
def build_graph(self, task_id: str) -> Optional[TaskGraph]:
|
||||||
|
"""Build task graph for a task"""
|
||||||
|
task = self._tasks.get(task_id)
|
||||||
|
if not task:
|
||||||
|
return None
|
||||||
|
return TaskGraph(task)
|
||||||
|
|
||||||
|
|
||||||
|
task_service = TaskService()
|
||||||
Loading…
Reference in New Issue