263 lines
9.3 KiB
Python
263 lines
9.3 KiB
Python
"""Tool executor with caching and parallel execution support.

This module follows the Single Responsibility Principle:

- ToolExecutor: tool execution logic
- CallHistory: call history management
- CacheManager: result caching logic
"""
|
|
import json
|
|
import time
|
|
import logging
|
|
from typing import List, Dict, Any, Optional
|
|
from threading import Lock
|
|
from luxx.tools.core import registry, ToolContext
|
|
|
|
# Module-level logger; ToolExecutor reports tool and parallel-execution failures through it.
logger = logging.getLogger(__name__)
|
|
|
|
class CacheManager:
    """Thread-safe, TTL-based cache for tool results.

    Entries are stored as ``(result, timestamp)`` pairs and are considered
    stale once ``cache_ttl`` seconds have passed since they were stored.
    Stale entries are dropped when they are read and whenever a new entry is
    stored, so the cache does not grow without bound in long-running sessions.
    """

    def __init__(self, enable_cache: bool = True, cache_ttl: int = 300):
        """Create a cache.

        Args:
            enable_cache: When False, ``get`` always misses and ``set`` is a no-op.
            cache_ttl: Entry lifetime in seconds.
        """
        self.enable_cache = enable_cache
        self.cache_ttl = cache_ttl
        self._cache: Dict[str, tuple] = {}  # key -> (result, timestamp)
        self._lock = Lock()

    def make_key(self, name: str, args: dict, workspace: Optional[str] = None) -> str:
        """Build a deterministic cache key from tool name, arguments and workspace.

        Arguments are serialized with ``sort_keys=True`` so that argument
        order does not affect the key. The workspace is appended only when it
        is truthy.
        """
        args_str = json.dumps(args, sort_keys=True, ensure_ascii=False)
        key = f"{name}:{args_str}"
        if workspace:
            key = f"{key}:{workspace}"
        return key

    def _is_fresh(self, timestamp: float, now: float) -> bool:
        """Return True if an entry stored at ``timestamp`` is still within the TTL."""
        return now - timestamp < self.cache_ttl

    def _prune_expired_locked(self, now: float) -> None:
        """Drop every stale entry. The caller must already hold ``self._lock``."""
        stale_keys = [
            key for key, (_, timestamp) in self._cache.items()
            if not self._is_fresh(timestamp, now)
        ]
        for key in stale_keys:
            del self._cache[key]

    def is_valid(self, cache_key: str) -> bool:
        """Return True if ``cache_key`` is present and has not expired."""
        with self._lock:
            entry = self._cache.get(cache_key)
        return entry is not None and self._is_fresh(entry[1], time.time())

    def get(self, cache_key: str) -> Optional[Dict]:
        """Return the cached result for ``cache_key``, or None on a miss.

        Disabled caches, missing keys and expired entries all count as misses.
        An expired entry is removed as it is read. The lookup and the read
        happen under a single lock acquisition, so a concurrent ``clear()``
        cannot make this raise.
        """
        if not self.enable_cache:
            return None
        now = time.time()
        with self._lock:
            entry = self._cache.get(cache_key)
            if entry is None:
                return None
            result, timestamp = entry
            if self._is_fresh(timestamp, now):
                return result
            del self._cache[cache_key]
            return None

    def set(self, cache_key: str, result: Dict) -> None:
        """Store ``result`` under ``cache_key`` (no-op when caching is disabled).

        Expired entries are pruned first. This bounds memory to the entries
        written within the last ``cache_ttl`` seconds.
        """
        if not self.enable_cache:
            return
        now = time.time()
        with self._lock:
            self._prune_expired_locked(now)
            self._cache[cache_key] = (result, now)

    def clear(self) -> None:
        """Remove all cached entries."""
        with self._lock:
            self._cache.clear()

    def size(self) -> int:
        """Return the number of stored entries (expired but unpruned ones included)."""
        with self._lock:
            return len(self._cache)
|
|
|
|
|
|
class CallHistory:
    """Thread-safe, bounded record of recent tool calls.

    Only the newest ``MAX_HISTORY_SIZE`` entries are kept. Older entries are
    discarded automatically as new ones are recorded.
    """

    MAX_HISTORY_SIZE = 1000

    def __init__(self):
        from collections import deque

        # deque(maxlen=...) drops the oldest entry in O(1) on overflow,
        # instead of re-slicing the whole list on every append once full.
        self._history = deque(maxlen=self.MAX_HISTORY_SIZE)
        self._lock = Lock()

    def record(self, name: str, args: dict, result: Dict) -> None:
        """Append one call entry with keys name/args/result/timestamp."""
        entry = {
            "name": name,
            "args": args,
            "result": result,
            "timestamp": time.time()
        }
        with self._lock:
            self._history.append(entry)

    def get(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Return up to ``limit`` most recent entries, oldest first.

        A non-positive ``limit`` returns an empty list. Slicing with ``[-0:]``
        would otherwise return the entire history.
        """
        if limit <= 0:
            return []
        with self._lock:
            snapshot = list(self._history)
        return snapshot[-limit:]

    def clear(self) -> None:
        """Remove all recorded entries."""
        with self._lock:
            self._history.clear()

    def size(self) -> int:
        """Return the number of recorded entries."""
        with self._lock:
            return len(self._history)
|
|
|
|
|
|
class ToolExecutor:
    """Tool executor with caching and parallel execution support.

    Caching is delegated to ``CacheManager`` and call recording to
    ``CallHistory``, following the Single Responsibility Principle. This
    class only orchestrates tool execution.
    """

    def __init__(
        self,
        enable_cache: bool = True,
        cache_ttl: int = 300,
        max_workers: int = 4
    ):
        """Create an executor.

        Args:
            enable_cache: Whether results returned by the registry are cached.
            cache_ttl: Cache entry lifetime in seconds.
            max_workers: Thread-pool size used by ``process_tool_calls_parallel``.
        """
        self.cache = CacheManager(enable_cache=enable_cache, cache_ttl=cache_ttl)
        self.history = CallHistory()
        self.max_workers = max_workers

    def process_tool_calls(
        self,
        tool_calls: List[Dict[str, Any]],
        context: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        """Process tool calls sequentially.

        Args:
            tool_calls: Raw tool calls. Each is read for ``id`` and a
                ``function`` dict holding ``name`` and JSON ``arguments``.
            context: Request context used to build the ``ToolContext``.

        Returns:
            One tool-result message per input call, in input order. A tool
            that raises yields an error result (not cached) instead of
            aborting the remaining calls, matching the parallel path.
        """
        tool_ctx = self._build_tool_context(context)
        results = []

        for call in tool_calls:
            call_id, name, args, cache_key = self._prepare_call(call, tool_ctx)

            result = self.cache.get(cache_key)
            if result is None:
                try:
                    result = registry.execute(name, args, context=tool_ctx)
                except Exception as exc:
                    result = self._error_result(name, exc)
                else:
                    self.cache.set(cache_key, result)

            self.history.record(name, args, result)
            results.append(self._create_tool_result(call_id, name, result))

        return results

    def process_tool_calls_parallel(
        self,
        tool_calls: List[Dict[str, Any]],
        context: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        """Process tool calls in parallel on a thread pool.

        IMPORTANT: Results are returned in the SAME ORDER as input tool_calls,
        not in completion order. This ensures proper matching between tool_call
        and tool_result steps in the frontend.

        Cache hits are served immediately. Misses are submitted to the pool,
        and their results are cached and recorded as they are collected. A
        single failing tool becomes an error result. Other calls are never
        re-executed, because the only sequential fallback happens before any
        tool has run.
        """
        if len(tool_calls) <= 1:
            return self.process_tool_calls(tool_calls, context)

        from concurrent.futures import ThreadPoolExecutor

        tool_ctx = self._build_tool_context(context)

        try:
            executor = ThreadPoolExecutor(max_workers=self.max_workers)
        except (ValueError, TypeError) as e:
            # Invalid max_workers. Nothing has executed yet, so falling back
            # to sequential execution cannot run any tool twice.
            logger.error(f"[EXECUTOR] Parallel execution failed: {type(e).__name__}: {e}")
            return self.process_tool_calls(tool_calls, context)

        results: List[Optional[Dict[str, Any]]] = [None] * len(tool_calls)
        # future -> (original index, call_id, name, args, cache_key)
        pending = {}

        with executor:
            for idx, call in enumerate(tool_calls):
                call_id, name, args, cache_key = self._prepare_call(call, tool_ctx)
                cached = self.cache.get(cache_key)
                if cached is not None:
                    self.history.record(name, args, cached)
                    results[idx] = self._create_tool_result(call_id, name, cached)
                else:
                    # NOTE(review): tool_ctx is shared by all worker threads;
                    # assumes registry tools treat it as read-only.
                    future = executor.submit(
                        registry.execute, name, args, context=tool_ctx
                    )
                    pending[future] = (idx, call_id, name, args, cache_key)

            # Iterate in submission order and write each result to its
            # original index so output order matches input order.
            for future, (idx, call_id, name, args, cache_key) in pending.items():
                try:
                    result = future.result()
                except Exception as exc:
                    result = self._error_result(name, exc)
                else:
                    self.cache.set(cache_key, result)
                self.history.record(name, args, result)
                results[idx] = self._create_tool_result(call_id, name, result)

        # Every index is filled above. The filter is only a safety net.
        return [r for r in results if r is not None]

    def _prepare_call(self, call: Dict[str, Any], tool_ctx: "ToolContext") -> tuple:
        """Extract ``(call_id, name, args, cache_key)`` from one raw tool call."""
        function = call.get("function") or {}
        name = function.get("name", "")
        args = self._parse_arguments(call)
        cache_key = self.cache.make_key(name, args, tool_ctx.workspace)
        return call.get("id", ""), name, args, cache_key

    def _error_result(self, name: str, exc: Exception) -> Dict[str, Any]:
        """Log a failed tool execution and build the error payload for the model."""
        logger.error(f"[EXECUTOR] Tool '{name}' execution failed: {type(exc).__name__}: {exc}")
        return {"success": False, "error": f"{type(exc).__name__}: {str(exc)}"}

    def _build_tool_context(self, context: Dict[str, Any]) -> "ToolContext":
        """Build a ToolContext from the request context dict.

        ``user_permission_level`` defaults to 1. A missing or None ``extra``
        is treated as empty.
        """
        extra = context.get("extra") or {}
        return ToolContext(
            workspace=context.get("workspace"),
            user_id=context.get("user_id"),
            username=context.get("username"),
            extra={
                "user_permission_level": context.get("user_permission_level", 1),
                # NOTE(review): keys in ``extra`` override the permission level
                # above; confirm that callers cannot inject this.
                **extra,
            }
        )

    def _parse_arguments(self, call: Dict[str, Any]) -> Dict:
        """Return the call's arguments as a dict.

        Accepts a JSON-encoded string or an already-decoded dict. Missing,
        empty, or None arguments, malformed JSON, and JSON that is not an
        object all yield ``{}``.
        """
        raw = (call.get("function") or {}).get("arguments")
        if isinstance(raw, dict):
            return raw
        if not raw:
            return {}
        try:
            parsed = json.loads(raw)
        except (json.JSONDecodeError, TypeError) as e:
            logger.warning(f"[EXECUTOR] Could not parse tool arguments: {type(e).__name__}: {e}")
            return {}
        if not isinstance(parsed, dict):
            logger.warning(f"[EXECUTOR] Tool arguments are not a JSON object: {type(parsed).__name__}")
            return {}
        return parsed

    def _create_tool_result(self, call_id: str, name: str, result: Dict) -> Dict[str, Any]:
        """Create a tool result message with ``tool_call_id``, ``role`` and ``content``.

        ``content`` is ``result`` serialized as JSON. If ``result`` cannot be
        serialized, an error payload is sent instead so one bad result does not
        abort the whole batch. ``name`` is used only for logging.
        """
        try:
            content = json.dumps(result, ensure_ascii=False)
        except (TypeError, ValueError) as e:
            logger.error(f"[EXECUTOR] Tool '{name}' result is not JSON-serializable: {type(e).__name__}: {e}")
            content = json.dumps(
                {"success": False, "error": f"{type(e).__name__}: {str(e)}"},
                ensure_ascii=False
            )
        return {
            "tool_call_id": call_id,
            "role": "tool",
            "content": content
        }

    def clear_cache(self) -> None:
        """Clear all cached tool results."""
        self.cache.clear()

    def get_history(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Return up to ``limit`` most recent call-history entries."""
        return self.history.get(limit)
|