276 lines
10 KiB
Python
276 lines
10 KiB
Python
"""AgenticLoop - executes the agentic loop: LLM + tool iteration.

The loop:
1. Call the LLM with the messages and tools.
2. Check the response for tool calls.
3. Execute the requested tools in parallel.
4. Append the tool results to the messages.
5. Repeat, for at most MAX_ITERATIONS (10) iterations.
6. Emit the final response (as a ``done`` SSE event).
"""
|
|
import json
|
|
import uuid
|
|
import logging
|
|
import traceback
|
|
from typing import List, Dict, Any, AsyncGenerator
|
|
|
|
from luxx.tools.executor import ToolExecutor
|
|
from luxx.services.llm_client import LLMClient
|
|
from luxx.services.stream_context import StreamContext, _sse_event
|
|
from luxx.services.process_result import ProcessResult
|
|
from luxx.services.llm_response import llm_parser
|
|
|
|
logger = logging.getLogger(name=__name__)

# Safety valve: the agentic loop gives up after this many LLM calls so a
# model that keeps requesting tools cannot spin forever.
MAX_ITERATIONS: int = 10
|
|
|
|
|
|
def _parse_sse_line(line: str) -> tuple:
    """Parse one SSE event block into ``(event_type, data_str)``.

    Recognises the ``event`` and ``data`` fields of the Server-Sent Events
    format. As the SSE spec allows, the space after the colon is optional,
    and several ``data`` lines are joined with a newline character. Field
    values are stripped of surrounding whitespace; other fields and ``:``
    comment lines are ignored.

    Args:
        line: Raw event text, possibly spanning several lines separated by
            LF, CRLF or CR (the only SSE line terminators).

    Returns:
        ``(event_type, data_str)``; either element is ``None`` when the
        corresponding field is absent.
    """
    event_type = None
    data_parts = []
    # Normalise the three SSE line terminators only. str.splitlines() is
    # deliberately avoided: it also splits on U+2028 etc., which may appear
    # raw inside JSON string payloads.
    normalised = line.strip().replace('\r\n', '\n').replace('\r', '\n')
    for field_line in normalised.split('\n'):
        name, sep, value = field_line.partition(':')
        if not sep:
            # No colon (e.g. blank separator line): nothing to record.
            continue
        if name == 'event':
            event_type = value.strip()
        elif name == 'data':
            data_parts.append(value.strip())
    data_str = '\n'.join(data_parts) if data_parts else None
    return event_type, data_str
|
|
|
|
|
|
class AgenticLoop:
    """Executes the Agentic Loop: LLM + Tools iteration.

    Each iteration streams one LLM call and relays its thinking/text as
    ``process_step`` SSE events. If the model requested tools, they are run
    and their results appended to ``messages`` before the LLM is called
    again. The loop stops on the first response without tool calls (emitting
    ``done``), on an error, or after ``MAX_ITERATIONS`` iterations.
    """

    def __init__(self, tool_executor: ToolExecutor):
        """Store the executor used to run tool calls requested by the LLM."""
        self.tool_executor = tool_executor

    async def execute(
        self,
        llm: LLMClient,
        model: str,
        messages: List[Dict],
        tools: list,
        temperature: float,
        max_tokens: int,
        thinking_enabled: bool,
        context: 'StreamContext',
        tool_context: dict = None
    ) -> AsyncGenerator[str, None]:
        """Execute the agentic loop.

        Yields SSE events for each step.

        Args:
            llm: Client whose ``stream_call`` yields raw SSE lines.
            model, tools, temperature, max_tokens, thinking_enabled:
                Forwarded unchanged to ``llm.stream_call``.
            messages: Conversation history. Mutated in place: after each
                tool round the assistant tool-call message and the tool
                results are appended.
            context: Per-request stream state; ``reset()`` at the start of
                every iteration.
            tool_context: Extra context handed to the tool executor
                (``{}`` when None).

        Yields:
            SSE event strings (``process_step``, tool steps, ``error``,
            ``done``).
        """
        total_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
        # Bound before the loop so the post-loop check can never hit an
        # unbound name (e.g. if MAX_ITERATIONS were set to 0).
        has_error = False

        for _iteration in range(MAX_ITERATIONS):
            context.reset()

            # Stream one LLM response, relaying events as they are produced.
            async for sse_line in llm.stream_call(
                model=model,
                messages=messages,
                tools=tools,
                temperature=temperature,
                max_tokens=max_tokens,
                thinking_enabled=thinking_enabled
            ):
                result = self._process_stream_line(sse_line, context, total_usage)
                for event in result.events:
                    yield event
                if result.has_error:
                    has_error = True
                    break

            # The error event was already yielded; stop without "done".
            if has_error:
                break

            context.finalize_step()

            if context.tool_calls_list:
                # Model asked for tools: run them, then loop so the LLM sees
                # their results.
                # NOTE(review): _execute_tools calls the executor
                # synchronously, so the event loop is blocked until all tools
                # return; consider asyncio.to_thread if that is slow.
                for event in self._execute_tools(context, messages, tool_context):
                    yield event
                continue

            # No tool calls: this response is the final answer.
            for event in self._complete(context, total_usage):
                yield event
            return

        # Iterations exhausted without a final answer (errors broke out above).
        if not has_error:
            yield _sse_event("error", {"content": "Exceeded maximum tool call iterations"})

    @staticmethod
    def _emit_step(ctx: 'StreamContext', result: ProcessResult, step_type: str) -> None:
        """Ensure a ``step_type`` step is open on ctx and emit its full content.

        The content sent is ctx's accumulated thinking (for "thinking") or
        text (for "text"), read after any new step has been started.
        """
        if not ctx.current_step_id or ctx.current_step_type != step_type:
            ctx.start_step(step_type)
        content = ctx.full_thinking if step_type == "thinking" else ctx.full_content
        result.add_event(_sse_event("process_step", {
            "step": {
                "id": ctx.current_step_id,
                "index": ctx.current_step_idx,
                "type": step_type,
                "content": content
            }
        }))
        result.set_content()

    def _process_stream_line(self, sse_line: str, ctx: 'StreamContext',
                             total_usage: dict) -> ProcessResult:
        """Process single SSE line from LLM, return result with events and flags.

        Args:
            sse_line: One raw SSE event from the LLM client.
            ctx: Stream state. Accumulated thinking/text, the current step
                and tool-call fragments are updated in place.
            total_usage: Overwritten with the token counts of any chunk that
                carries a non-empty ``usage`` object.

        Returns:
            A ProcessResult holding the SSE events to forward, plus flags for
            error / content / tool-call activity.
        """
        result = ProcessResult()
        event_type, data_str = _parse_sse_line(sse_line)
        if not data_str:
            return result

        # Client-level error event: forward its message.
        if event_type == 'error':
            try:
                error_data = json.loads(data_str)
            except json.JSONDecodeError:
                error_data = None
            if isinstance(error_data, dict):
                error_content = error_data.get("content", "Unknown error")
            else:
                # Not JSON, or JSON that is not an object: report raw text.
                error_content = data_str
            result.set_error(error_content)
            result.add_event(_sse_event("error", {"content": error_content}))
            return result

        try:
            chunk = json.loads(data_str)
        except json.JSONDecodeError:
            chunk = None
        if not isinstance(chunk, dict):
            # Malformed JSON, or valid JSON that is not an object (the key
            # lookups below would raise on a list/str/number/null).
            error_msg = f"Parse error: {data_str[:50]}"
            result.set_error(error_msg)
            result.add_event(_sse_event("error", {"content": error_msg}))
            return result

        # Record token usage whenever a chunk reports it; null counts -> 0.
        # NOTE(review): update() overwrites, so after several iterations
        # total_usage holds only the latest report, not a sum across LLM
        # calls; confirm which the "done" consumers expect.
        usage = chunk.get("usage")
        if isinstance(usage, dict) and usage:
            total_usage.update({
                "prompt_tokens": usage.get("prompt_tokens") or 0,
                "completion_tokens": usage.get("completion_tokens") or 0,
                "total_tokens": usage.get("total_tokens") or 0
            })

        # API error embedded in a data chunk; may be an object or a string.
        error = chunk.get("error")
        if error is not None:
            if isinstance(error, dict):
                error_msg = error.get("message", str(error))
            else:
                error_msg = str(error)
            result.set_error(error_msg)
            result.add_event(_sse_event("error", {"content": f"API Error: {error_msg}"}))
            return result

        choices = chunk.get("choices") or []
        if not choices:
            # Non-standard format: content sits directly on the chunk and may
            # embed thinking tags; each part is appended to ctx.
            content = chunk.get("content") or ""
            if content:
                thinking_part, clean_text = llm_parser._extract_thinking_tags(content)
                if thinking_part:
                    ctx.full_thinking = (ctx.full_thinking or "") + thinking_part
                    self._emit_step(ctx, result, "thinking")
                if clean_text:
                    ctx.full_content = (ctx.full_content or "") + clean_text
                    self._emit_step(ctx, result, "text")
            return result

        # Some providers send an explicit "delta": null.
        delta = choices[0].get("delta") or {}
        parsed = llm_parser.parse_openai(delta)

        # NOTE(review): unlike the non-standard branch above, these assign
        # rather than append, so this relies on parse_openai returning the
        # accumulated thinking/text rather than only this delta; confirm.
        if parsed.thinking:
            ctx.full_thinking = parsed.thinking
            self._emit_step(ctx, result, "thinking")

        if parsed.text:
            ctx.full_content = parsed.text
            self._emit_step(ctx, result, "text")

        # Tool-call fragments are merged by ctx. `or []` also covers an
        # explicit "tool_calls": null, which dict.get's default does not.
        for tc in parsed.tool_calls or delta.get("tool_calls") or []:
            ctx.accumulate_tool_call(tc)
            result.set_tool_calls()

        return result

    def _execute_tools(self, ctx: 'StreamContext', messages: list,
                       tool_context: dict = None) -> List[str]:
        """Run the tool calls collected on ctx and return the SSE events produced.

        Emits the tool-call steps, runs every call through the executor,
        emits one result step per call, and then appends two things to
        ``messages`` for the next LLM call. These are the assistant tool-call
        message and the tool result messages.
        """
        events = list(ctx.emit_tool_calls())

        tool_results = self.tool_executor.process_tool_calls_parallel(
            ctx.tool_calls_list, tool_context or {}
        )

        # Ids of the tool_call steps emitted for this round, in step order,
        # so each result step can reference its originating call.
        tool_ids = [tc.get("id") for tc in ctx.tool_calls_list]
        tool_step_ids = [
            s["id"] for s in ctx.all_steps
            if s["type"] == "tool_call" and s.get("id_ref") in tool_ids
        ]

        for i, (tr, _tc) in enumerate(zip(tool_results, ctx.tool_calls_list)):
            if i < len(tool_step_ids):
                ref_id = tool_step_ids[i]
            else:
                # NOTE(review): positional fallback; len(ctx.all_steps) is
                # re-read each pass, so it shifts if emit_tool_result appends
                # steps. Confirm this is intended.
                ref_id = f"step-{len(ctx.all_steps) - len(tool_results) + i}"
            _, event = ctx.emit_tool_result(tr, ref_id)
            events.append(event)

        # Snapshot the list so this history entry cannot change if ctx
        # clears tool_calls_list in place on the next reset().
        messages.append({
            "role": "assistant",
            "content": ctx.full_content or "",
            "tool_calls": list(ctx.tool_calls_list)
        })
        # Guard the slice: with zero results, [-0:] would re-append the
        # *entire* tool-result history.
        if tool_results:
            messages.extend(ctx.all_tool_results[-len(tool_results):])

        return events

    def _complete(self, ctx: 'StreamContext', total_usage: dict) -> List[str]:
        """Finish the loop: record completion on ctx and return the ``done`` event.

        ``token_count`` is the reported completion_tokens. When none was
        reported, it falls back to a rough ``len(text) // 4`` estimate.
        """
        # full_content may be None when no text was streamed (it is treated
        # as "" elsewhere in this class too).
        token_count = total_usage.get("completion_tokens") or len(ctx.full_content or "") // 4
        msg_id = str(uuid.uuid4())
        logger.info("[TOKEN] usage=%s, count=%s", total_usage, token_count)

        ctx.set_completion(msg_id, token_count, total_usage)

        return [_sse_event("done", {
            "message_id": msg_id,
            "token_count": token_count,
            "usage": total_usage
        })]
|