Files
local_swarm/src/api/chat_handlers.py
T
sleepy 93844a81b0 refactor: unified generation interface for federation and local modes
- Created _generate_with_consensus() that handles both federation and local generation
- Callers don't need to know which mode is being used - it's transparent
- Tool execution loop uses same unified interface for all iterations
- Removed special-case federation logic from main handler
- Federation is now a transparent layer around generation
- All 41 tests passing
2026-02-25 23:36:24 +01:00

744 lines
30 KiB
Python

"""Chat completion handlers for Local Swarm.
Contains the business logic for chat completions, separated from HTTP routing.
"""
import json
import logging
import time
import uuid
from typing import Optional, List
from api.models import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionChoice,
ChatMessage,
UsageInfo,
)
from api.formatting import format_messages_with_tools
from api.tool_parser import parse_tool_calls
from utils.token_counter import count_tokens
from tools.executor import get_tool_executor
from chatlog import get_chat_logger
logger = logging.getLogger(__name__)
def _extract_working_dir_from_prompt(prompt: str) -> Optional[str]:
"""Extract working directory from user prompt.
Looks for patterns like:
- "in the /path/to/dir directory"
- "in directory /path/to/dir"
- "in /path/to/dir"
- "under /path/to/dir"
- "from /path/to/dir"
Args:
prompt: User prompt text
Returns:
Extracted directory path or None
"""
import re
import os
# Common patterns for directory mentions
patterns = [
r'in the\s+([/~]?[\w\-/.]+)\s+(?:directory|folder|dir)',
r'in\s+(?:directory|folder|dir)\s+([/~]?[\w\-/.]+)',
r'(?:in|under|from|at)\s+([/~]?[\w\-/.]{3,})', # At least 3 chars to avoid "in a"
]
for pattern in patterns:
match = re.search(pattern, prompt, re.IGNORECASE)
if match:
path = match.group(1)
# Validate it looks like a path
if path.startswith('/') or path.startswith('~') or '/' in path:
# Expand home directory
if path.startswith('~'):
path = os.path.expanduser(path)
# Check if it's a valid directory or parent exists
if os.path.isdir(path) or os.path.isdir(os.path.dirname(path)):
return os.path.abspath(path)
return None
def _sanitize_tools(tools: Optional[list]) -> Optional[list]:
"""Sanitize tool definitions to fix invalid schemas.
Removes extra 'description' from properties if present.
Args:
tools: List of tool definitions
Returns:
Sanitized tools list
"""
if not tools:
return tools
sanitized = []
for tool in tools:
if tool.type == "function" and tool.function.parameters:
params = tool.function.parameters
# Remove invalid 'description' from properties if present
if 'properties' in params and 'description' in params.get('properties', {}):
invalid_props = ['description']
# Also remove 'description' from required if present
if 'required' in params:
params['required'] = [r for r in params.get('required', []) if r not in invalid_props]
# Remove invalid properties
params['properties'] = {k: v for k, v in params.get('properties', {}).items() if k not in invalid_props}
logger.debug(f" 🔧 Sanitized tool '{tool.function.name}': removed {invalid_props} from properties/required")
sanitized.append(tool)
return sanitized
async def _execute_tools(
    tool_calls: list,
    client_working_dir: Optional[str],
    executor
) -> List[tuple]:
    """Execute tool calls sequentially and return their results.

    Args:
        tool_calls: List of parsed tool calls (dicts with a "function" entry
            holding "name" and "arguments").
        client_working_dir: Working directory for file operations, or None.
        executor: Tool executor instance (accepted for interface compatibility;
            execution is delegated to execute_tool_server_side).

    Returns:
        List of tuples (tool_name, result_string), in call order.
    """
    # Imported here rather than at module top — presumably to avoid a
    # circular import with api.routes (TODO confirm).
    from api.routes import execute_tool_server_side
    tool_results = []
    for i, tc in enumerate(tool_calls):
        tool_name = tc.get("function", {}).get("name", "")
        tool_args_str = tc.get("function", {}).get("arguments", "{}")
        try:
            tool_args = json.loads(tool_args_str) if isinstance(tool_args_str, str) else tool_args_str
        except (json.JSONDecodeError, TypeError):
            # Was a bare `except:` that swallowed everything (including
            # KeyboardInterrupt); narrowed to the decode failures we expect.
            # Malformed arguments fall back to an empty argument dict.
            tool_args = {}
        logger.debug(f" [{i+1}/{len(tool_calls)}] Executing: {tool_name}({tool_args})")
        result = await execute_tool_server_side(tool_name, tool_args, working_dir=client_working_dir)
        tool_results.append((tool_name, result))
        logger.debug(f" ✓ Completed: {result[:100]}..." if len(result) > 100 else f" ✓ Result: {result}")
    return tool_results
def _create_response(
    content: str,
    tool_calls: list,
    finish_reason: str,
    prompt: str,
    request: ChatCompletionRequest,
    swarm_manager=None,
    thinking_content: Optional[str] = None
) -> ChatCompletionResponse:
    """Create a chat completion response.

    Note: a stale duplicate docstring that followed this one (dead string
    statement) has been removed.

    Args:
        content: Final response content (after tool execution if any).
        tool_calls: List of tool calls to attach to the assistant message.
        finish_reason: Finish reason (e.g. "stop", "tool_calls").
        prompt: Original prompt, used only for token counting.
        request: Original request (supplies the fallback model name).
        swarm_manager: Swarm manager instance (optional, for getting model name).
        thinking_content: Intermediate thinking/planning content, attached as a
            private attribute so streaming can emit it as reasoning_content.

    Returns:
        ChatCompletionResponse
    """
    # Ensure content is at least an empty string (never None for OpenAI compatibility)
    if content is None:
        content = ""
    prompt_tokens = count_tokens(prompt)
    completion_tokens = count_tokens(content)
    total_tokens = prompt_tokens + completion_tokens
    # Get actual model name from swarm manager; fall back to the requested model.
    model_name = request.model
    system_fingerprint = None
    if swarm_manager:
        status = swarm_manager.get_status()
        model_name = status.model_name
        # Sanitize system_fingerprint to only include safe characters
        import re
        raw_fingerprint = model_name.lower().replace(" ", "-")
        system_fingerprint = re.sub(r'[^a-z0-9\-_]', '', raw_fingerprint)
    # Build message - omit tool_calls entirely if empty (OpenAI behavior)
    message_kwargs = {
        "role": "assistant",
        "content": content
    }
    if tool_calls:
        message_kwargs["tool_calls"] = tool_calls
    message = ChatMessage(**message_kwargs)
    response = ChatCompletionResponse(
        id=f"chatcmpl-{uuid.uuid4().hex[:12]}",
        created=int(time.time()),
        model=model_name,
        choices=[
            ChatCompletionChoice(
                index=0,
                message=message,
                logprobs=None,
                finish_reason=finish_reason
            )
        ],
        usage=UsageInfo(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens
        ),
        stats={},
        system_fingerprint=system_fingerprint
    )
    # Attach thinking content for streaming (not part of JSON serialization)
    # Use a private attribute to avoid interfering with model serialization
    if thinking_content is not None:
        setattr(response, '_thinking', thinking_content)
    return response
async def _generate_with_consensus(
    prompt: str,
    max_tokens: int,
    temperature: float,
    swarm_manager,
    federated_swarm=None
) -> tuple[str, int, float]:
    """Generate response with consensus (local or federated).

    This is the unified generation interface - it handles both local-only
    and federated generation transparently. Callers don't need to know
    which mode is being used.

    Args:
        prompt: Prompt to generate from
        max_tokens: Maximum tokens to generate
        temperature: Sampling temperature
        swarm_manager: Local swarm manager instance
        federated_swarm: Optional federated swarm for multi-node consensus

    Returns:
        Tuple of (response_text, tokens_generated, tokens_per_second).
        In federation mode tokens/TPS are reported as 0 and 0.0.
    """
    # Check if federation is available: a federated swarm with at least one peer.
    if federated_swarm is not None:
        peers = federated_swarm.discovery.get_peers()
        if peers:
            logger.debug(f"🌐 Using federation with {len(peers)} peer(s)")
            try:
                # NOTE(review): the result is unpacked here as a 3-tuple, while
                # _generate_with_federation elsewhere in this module treats the
                # same call's result as an object with a .final_response
                # attribute — confirm which return contract
                # generate_with_federation actually has; one call site looks stale.
                content, tool_calls, finish_reason = await federated_swarm.generate_with_federation(
                    prompt=prompt,
                    max_tokens=max_tokens,
                    temperature=temperature
                )
                # Federation returns content directly
                # Note: tool_calls from federation should be ignored - head node handles tools
                return content, 0, 0.0  # Tokens/TPS not tracked in federation mode
            except Exception as e:
                # Federation errors are non-fatal: log and fall back to local.
                logger.warning(f"Federation failed, falling back to local: {e}")
                # Fall through to local generation
    # Local generation (fallback or no federation)
    try:
        result = await swarm_manager.generate(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            use_consensus=True
        )
        response = result.selected_response
        return response.text, response.tokens_generated, response.tokens_per_second
    except Exception as e:
        # Log with traceback, then re-raise — the caller surfaces the error.
        logger.exception("Error in swarm generation")
        raise
def _tool_calls_agree(tool_calls_list: List[List[dict]]) -> bool:
"""Check if all workers agree on the same tool calls.
Args:
tool_calls_list: List of tool calls from each worker
Returns:
True if all workers have the same tool calls
"""
if not tool_calls_list:
return True
# Check if all have the same number of tool calls
first_count = len(tool_calls_list[0])
if not all(len(tc) == first_count for tc in tool_calls_list):
logger.warning(f" ⚠️ Workers disagree on number of tool calls: {[len(tc) for tc in tool_calls_list]}")
return False
if first_count == 0:
return True # All agree on no tools
# Check if tool names and arguments match
for i in range(first_count):
first_tool = tool_calls_list[0][i]
first_name = first_tool.get("function", {}).get("name", "")
first_args = first_tool.get("function", {}).get("arguments", "")
for j, other_calls in enumerate(tool_calls_list[1:], 1):
other_tool = other_calls[i]
other_name = other_tool.get("function", {}).get("name", "")
other_args = other_tool.get("function", {}).get("arguments", "")
if first_name != other_name:
logger.warning(f" ⚠️ Worker {j+1} disagrees on tool name: {first_name} vs {other_name}")
return False
# For arguments, do a loose comparison (ignore whitespace differences)
try:
first_args_norm = json.loads(first_args) if isinstance(first_args, str) else first_args
other_args_norm = json.loads(other_args) if isinstance(other_args, str) else other_args
if first_args_norm != other_args_norm:
logger.warning(f" ⚠️ Worker {j+1} disagrees on arguments for {first_name}")
return False
except json.JSONDecodeError:
# If JSON parsing fails, compare as strings
if str(first_args).strip() != str(other_args).strip():
logger.warning(f" ⚠️ Worker {j+1} disagrees on arguments for {first_name}")
return False
logger.info(f" ✅ All {len(tool_calls_list)} workers agree on tool calls")
return True
async def _generate_with_tool_consensus(
    swarm_manager,
    prompt: str,
    max_tokens: int,
    temperature: float
) -> tuple[str, List[dict], int, float]:
    """Generate response with tool call consensus checking.

    When multiple workers are active, this ensures they all agree on tool calls
    before executing them. If they disagree, returns the best response without tools.

    Args:
        swarm_manager: Swarm manager instance
        prompt: Prompt to generate from
        max_tokens: Maximum tokens to generate
        temperature: Sampling temperature

    Returns:
        Tuple of (response_text, tool_calls, tokens_generated, tps)
    """
    try:
        # Get status to check number of workers
        status = swarm_manager.get_status()
        num_workers = getattr(status, 'active_workers', 1)
        # If only one worker, use normal generation - no consensus needed.
        if num_workers <= 1:
            logger.debug(" Single worker mode - skipping tool consensus")
            result = await swarm_manager.generate(
                prompt=prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                use_consensus=True
            )
            response = result.selected_response
            parsed_content, tool_calls = parse_tool_calls(response.text)
            return response.text, tool_calls, response.tokens_generated, response.tokens_per_second
        # Multiple workers - check for tool consensus
        logger.info(f" 🔍 Checking tool consensus across {num_workers} workers...")
        # Generate from all workers individually
        from swarm.manager import GenerationRequest
        all_responses = []
        all_tool_calls = []
        # Get all active workers
        workers = swarm_manager.workers if hasattr(swarm_manager, 'workers') else []
        if not workers:
            # No direct worker access - fall back to normal generation.
            result = await swarm_manager.generate(
                prompt=prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                use_consensus=True
            )
            response = result.selected_response
            parsed_content, tool_calls = parse_tool_calls(response.text)
            return response.text, tool_calls, response.tokens_generated, response.tokens_per_second
        # Generate from each worker
        for i, worker in enumerate(workers):
            try:
                gen_result = await worker.generate(
                    GenerationRequest(prompt=prompt, max_tokens=max_tokens, temperature=temperature)
                )
                response_text = gen_result.text
                parsed_content, tool_calls = parse_tool_calls(response_text)
                all_responses.append(response_text)
                all_tool_calls.append(tool_calls)
                logger.debug(f" Worker {i+1}: {len(tool_calls)} tool call(s)")
            except Exception as e:
                # A failed worker contributes an empty response and no tool calls,
                # keeping the lists aligned by worker index.
                logger.warning(f" Worker {i+1} failed: {e}")
                all_responses.append("")
                all_tool_calls.append([])
        # Check consensus
        if _tool_calls_agree(all_tool_calls):
            # All agree - use the first response's tool calls
            best_response = all_responses[0] if all_responses else ""
            best_tool_calls = all_tool_calls[0] if all_tool_calls else []
            # Rough per-worker average word count of successful responses.
            # BUGFIX: previously divided by len([r for r in all_responses if r]),
            # which raised ZeroDivisionError when every worker failed.
            successful = [r for r in all_responses if r]
            if successful:
                total_tokens = sum(len(r.split()) for r in successful) // len(successful)
            else:
                total_tokens = 0
            avg_tps = 10.0  # Estimate
            return best_response, best_tool_calls, total_tokens, avg_tps
        else:
            # Disagreement - fall back to consensus strategy without tools
            logger.warning(" ⚠️ Tool consensus failed - falling back to text response")
            result = await swarm_manager.generate(
                prompt=prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                use_consensus=True
            )
            response = result.selected_response
            # Strip any tool calls to be safe
            parsed_content, _ = parse_tool_calls(response.text)
            return parsed_content, [], response.tokens_generated, response.tokens_per_second
    except Exception as e:
        logger.exception("Error in tool consensus generation")
        # Fall back to normal generation
        result = await swarm_manager.generate(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            use_consensus=True
        )
        response = result.selected_response
        parsed_content, tool_calls = parse_tool_calls(response.text)
        return response.text, tool_calls, response.tokens_generated, response.tokens_per_second
async def _generate_with_federation(
    federated_swarm,
    prompt: str,
    max_tokens: int,
    temperature: float
) -> tuple[str, list, str]:
    """Generate response using federated swarm.

    Args:
        federated_swarm: Federated swarm instance
        prompt: Prompt to generate from
        max_tokens: Maximum tokens to generate
        temperature: Sampling temperature

    Returns:
        Tuple of (response_text, tool_calls, finish_reason)
    """
    # NOTE(review): here the result is used as an object with a .final_response
    # attribute, while _generate_with_consensus unpacks the same call as a
    # (content, tool_calls, finish_reason) tuple — confirm which contract is
    # correct; one of the two call sites looks stale. Also verify whether this
    # helper is still called after the unified-generation refactor.
    result = await federated_swarm.generate_with_federation(
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        min_peers=0  # allow generation even with zero reachable peers
    )
    content = result.final_response or ""
    # Check for tool calls embedded in the generated text.
    content_parsed, tool_calls_parsed = parse_tool_calls(content)
    if tool_calls_parsed:
        # Tool calls present: return the content with calls stripped.
        return content_parsed or "", tool_calls_parsed, "tool_calls"
    return content or "", [], "stop"
async def handle_chat_completion(
    request: ChatCompletionRequest,
    swarm_manager,
    federated_swarm,
    client_working_dir: Optional[str],
    use_opencode_tools: bool
) -> ChatCompletionResponse:
    """Handle a chat completion request.

    Runs an agentic loop of up to max_iterations rounds: generate, parse tool
    calls, execute tools, then generate again with the tool results appended
    to the conversation history.

    Args:
        request: Chat completion request
        swarm_manager: Swarm manager instance
        federated_swarm: Optional federated swarm instance
        client_working_dir: Client working directory (extracted from the prompt
            when not supplied)
        use_opencode_tools: Whether to use opencode tool definitions

    Returns:
        Chat completion response
    """
    # Format messages into prompt. In opencode mode the (sanitized) client
    # tool definitions are embedded; otherwise tools are omitted from the prompt.
    if use_opencode_tools:
        sanitized_tools = _sanitize_tools(request.tools)
        prompt = format_messages_with_tools(request.messages, sanitized_tools)
        has_tools = sanitized_tools is not None and len(sanitized_tools) > 0
    else:
        prompt = format_messages_with_tools(request.messages, None)
        has_tools = request.tools is not None and len(request.tools) > 0
    # Initialize chat logger (if enabled via LOCAL_SWARM_CHATLOG=1)
    chat_logger = get_chat_logger()
    # Extract working directory from prompt if not provided by client
    if client_working_dir is None:
        # Scan user messages newest-first; stop at the first extracted path.
        for msg in reversed(request.messages):
            if msg.role == 'user':
                extracted_dir = _extract_working_dir_from_prompt(msg.content)
                if extracted_dir:
                    client_working_dir = extracted_dir
                    logger.info(f"📁 Extracted working directory from prompt: {client_working_dir}")
                    break
    # Log initial conversation history to chatlog
    for msg in request.messages:
        if msg.role == 'user':
            chat_logger.log_user_message(msg.content)
        elif msg.role == 'assistant':
            chat_logger.log_assistant_message(msg.content, has_tool_calls=bool(msg.tool_calls))
        elif msg.role == 'tool':
            chat_logger.log_tool_result("tool", msg.content)
    logger.info(f"\n{'='*60}")
    logger.info(f"CHAT COMPLETION REQUEST:")
    logger.info(f" has_tools={has_tools}, stream={request.stream}")
    logger.info(f" use_opencode={use_opencode_tools}")
    logger.info(f" messages={len(request.messages)}")
    logger.info(f"{'='*60}")
    # Build conversation history (mutable copy; tool results get appended).
    messages = list(request.messages)
    # Determine if we should use federation for generation.
    # NOTE(review): use_federation is only used for this log line —
    # _generate_with_consensus re-checks peer availability itself.
    use_federation = federated_swarm is not None and len(federated_swarm.discovery.get_peers()) > 0
    if use_federation:
        logger.info(f"🌐 Federation available with peers")
    # Track thinking content for streaming (OpenCode reasoning_content)
    thinking_content: Optional[str] = None
    thinking_captured = False
    # Initialize iteration counter and response text
    iteration = 0
    max_iterations = 3
    response_text = ""
    while iteration < max_iterations:
        iteration += 1
        logger.info(f"--- Tool Execution Iteration {iteration} ---")
        # Generate response (unified interface - handles federation automatically)
        # NOTE(review): every loop iteration generates from the ORIGINAL prompt
        # here; tool results only reach the model via the second generation
        # below (next_prompt). On iteration >= 2 this first generation appears
        # to discard the accumulated history — confirm this is intentional.
        logger.debug(f"Generating response...")
        response_text, tokens_generated, tps = await _generate_with_consensus(
            prompt=prompt,
            max_tokens=request.max_tokens or 1024,
            temperature=request.temperature or 0.7,
            swarm_manager=swarm_manager,
            federated_swarm=federated_swarm
        )
        logger.info(f"Generated response ({len(response_text)} chars, {tokens_generated} tokens)")
        logger.debug(f"Response: {response_text[:200]}...")
        # Check for tool calls
        parsed_content, tool_calls_parsed = parse_tool_calls(response_text)
        # Log assistant response to chatlog
        chat_logger.log_assistant_message(response_text, has_tool_calls=bool(tool_calls_parsed))
        if tool_calls_parsed:
            # Log each tool call
            for i, tc in enumerate(tool_calls_parsed, 1):
                tool_name = tc.get("function", {}).get("name", "")
                args_str = tc.get("function", {}).get("arguments", "{}")
                try:
                    args_dict = json.loads(args_str) if isinstance(args_str, str) else args_str
                except json.JSONDecodeError:
                    args_dict = {"raw": args_str}
                chat_logger.log_tool_call(tool_name, args_dict, i)
            # Capture thinking for OpenCode streaming (first occurrence only)
            if not thinking_captured:
                # Use the parsed content (without tool calls) as the reasoning
                thinking_content = parsed_content or ""
                thinking_captured = True
        if not tool_calls_parsed:
            # No more tools - this is the final answer
            logger.info(f"✅ Final answer (no tools) after {iteration} iteration(s)")
            return _create_response(parsed_content, [], "stop", prompt, request, swarm_manager, thinking_content)
        # Tools detected - execute them
        logger.info(f"🔧 Found {len(tool_calls_parsed)} tool call(s)")
        for i, tc in enumerate(tool_calls_parsed):
            tool_name = tc.get("function", {}).get("name", "")
            args_str = tc.get("function", {}).get("arguments", "{}")
            logger.info(f" [{i+1}] {tool_name}: {args_str[:100]}...")
        # Add assistant message to history with tool_calls (if any)
        # This preserves the tool call IDs for proper tool message association
        assistant_message = ChatMessage(
            role="assistant",
            content=response_text
        )
        if tool_calls_parsed:
            # Convert tool calls to proper ToolCall objects with IDs
            from api.models import ToolCall
            tc_objects = []
            for i, tc_dict in enumerate(tool_calls_parsed):
                # Synthesize an id when the parser did not provide one.
                tc_id = tc_dict.get("id", f"call_{i}")
                tc_objects.append(ToolCall(
                    id=tc_id,
                    type="function",
                    function={
                        "name": tc_dict["function"]["name"],
                        "arguments": tc_dict["function"]["arguments"]
                    }
                ))
            assistant_message.tool_calls = tc_objects
        messages.append(assistant_message)
        # Execute all tools
        logger.info(f"⏱️ Executing tools...")
        tool_results = await _execute_tools(tool_calls_parsed, client_working_dir, get_tool_executor())
        # Log tool results to chatlog (single combined log for debugging)
        combined_strings = [f"Tool {i+1} ({name}): {result}" for i, (name, result) in enumerate(tool_results)]
        chat_logger.log_tool_result("combined", "\n\n".join(combined_strings), success=True)
        # Add tool result to history - one message per tool call with proper tool_call_id
        for i, ((tool_name, tool_result), tc) in enumerate(zip(tool_results, tool_calls_parsed)):
            tool_call_id = tc.get("id", f"call_{i}")
            # Format the tool result message with explicit instructions
            # This tells the model exactly what to do with the result
            if tool_name == "read":
                instruction = "The file contents are shown above. READ THIS FILE CONTENT ALOUD to the user. Do not call additional tools."
            elif tool_name == "write":
                instruction = "The file has been successfully written. CONFIRM to the user that the file was created with the content shown above. Do not call additional tools."
            elif tool_name == "bash":
                # Check if this was a verification command (ls, grep) vs an action command
                if "ls" in tool_result.lower() or "grep" in tool_result.lower():
                    instruction = "CRITICAL: The listing is shown above. If the user asked to READ a specific file and you can see it exists in this listing, you MUST immediately USE THE read TOOL NOW with the exact filename from the listing. Do not summarize first - READ THE FILE immediately. Use the filename exactly as shown (e.g., 'my-secret.log' not '/path/to/my-secret.log'). If the user asked to just CHECK what files exist (without reading), then summarize. If the requested file is NOT in the listing, tell the user it doesn't exist."
                else:
                    instruction = "The command has been executed. SUMMARIZE the output above to answer the user's request. Do not call additional tools."
            else:
                instruction = "The tool has completed. Use the result shown above to answer the user's request. Do not call additional tools."
            tool_message_content = (
                f"Tool Result ({tool_name}):\n"
                f"{tool_result}\n\n"
                f"INSTRUCTION: {instruction}"
            )
            messages.append(ChatMessage(
                role="tool",
                content=tool_message_content,
                tool_call_id=tool_call_id,
                name=tool_name
            ))
            logger.info(f" ✓ Tool result {i+1} added to history (tool_call_id={tool_call_id}, name={tool_name})")
        logger.info(f"✅ Tools executed ({len(tool_results)} results)")
        # Continue loop - generate response with tool results
        logger.info(f"🔄 Generating response with tool results...")
        # Format with tool results (but DON'T include tool instruction - model should just use results)
        # NOTE(review): tool inclusion here is inverted relative to the initial
        # prompt (opencode mode passes None; non-opencode passes request.tools) —
        # confirm this asymmetry is intentional.
        next_prompt = format_messages_with_tools(messages, None if use_opencode_tools else request.tools)
        logger.info(f"📤 Prompt sent to model after tool execution:")
        logger.info(f" Total tokens: {count_tokens(next_prompt)}")
        logger.info(f" Messages in history: {len(messages)}")
        for i, msg in enumerate(messages):
            logger.info(f" [{i}] {msg.role}: {msg.content[:100]}{'...' if len(msg.content) > 100 else ''}")
            if msg.tool_calls:
                for j, tc in enumerate(msg.tool_calls):
                    logger.info(f" Tool call {j}: {tc.function.get('name')} ({tc.function.get('arguments')})")
            if msg.tool_call_id:
                logger.info(f" (tool_call_id: {msg.tool_call_id}, name: {msg.name})")
        logger.debug(f"Full prompt:\n{next_prompt[:1000]}...")
        response_text, tokens_generated, tps = await _generate_with_consensus(
            prompt=next_prompt,
            max_tokens=request.max_tokens or 1024,
            temperature=request.temperature or 0.7,
            swarm_manager=swarm_manager,
            federated_swarm=federated_swarm
        )
        logger.info(f"✅ Generated with tool results ({len(response_text)} chars, {tokens_generated} tokens)")
        logger.debug(f"Response: {response_text[:200]}...")
        # Check for more tools in the new response
        parsed_content, tool_calls_parsed = parse_tool_calls(response_text)
        # Log assistant response to chatlog
        chat_logger.log_assistant_message(response_text, has_tool_calls=bool(tool_calls_parsed))
        if tool_calls_parsed:
            # Log each tool call
            for i, tc in enumerate(tool_calls_parsed, 1):
                tool_name = tc.get("function", {}).get("name", "")
                args_str = tc.get("function", {}).get("arguments", "{}")
                try:
                    args_dict = json.loads(args_str) if isinstance(args_str, str) else args_str
                except json.JSONDecodeError:
                    args_dict = {"raw": args_str}
                chat_logger.log_tool_call(tool_name, args_dict, i)
            # Capture thinking if not already captured
            if not thinking_captured:
                thinking_content = parsed_content or ""
                thinking_captured = True
        if not tool_calls_parsed:
            # No more tools - final answer
            logger.info(f"✅ Final answer (after tool execution) after {iteration} iteration(s)")
            return _create_response(parsed_content, [], "stop", prompt, request, swarm_manager, thinking_content)
        # More tools detected - continue loop
        logger.info(f"🔧 More tools found - continuing loop")
    # Max iterations reached - force return last response
    logger.warning(f"⚠️ Max tool iterations ({max_iterations}) reached")
    logger.warning(f"⚠️ Returning last response (may include incomplete tool call)")
    return _create_response(response_text, [], "stop", prompt, request, swarm_manager, thinking_content)