fix: OpenAI API compatibility for hollama and other clients
- Fixed ChatMessage.tool_calls to be Optional with default None (excluded when empty) - Added logprobs field to ChatCompletionChoice (always included as null) - Added stats and system_fingerprint to ChatCompletionResponse - Fixed streaming response to use delta format (not message format) - Fixed non-streaming response to include logprobs: null - Updated tool instructions to include 'NO explanations' - Added pytest-asyncio markers to async tests - All 41 tests passing This fixes the 'Cannot read properties of undefined (reading content)' error in hollama and ensures compatibility with OpenAI clients.
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
"description": "Alibaba's code-focused model, excellent for small sizes",
|
||||
"priority": 1,
|
||||
"max_context": 128000,
|
||||
"hf_repo": "Qwen/Qwen2.5-Coder",
|
||||
"variants": ["3b", "7b", "14b"]
|
||||
},
|
||||
"deepseek-coder": {
|
||||
@@ -12,6 +13,7 @@
|
||||
"description": "DeepSeek's code model, good alternative",
|
||||
"priority": 2,
|
||||
"max_context": 16384,
|
||||
"hf_repo": "deepseek-ai/DeepSeek-Coder",
|
||||
"variants": ["1.3b", "6.7b"]
|
||||
},
|
||||
"deepseek-coder-v2-lite": {
|
||||
@@ -19,6 +21,7 @@
|
||||
"description": "DeepSeek's V2 Lite model with better MLX support",
|
||||
"priority": 2,
|
||||
"max_context": 16384,
|
||||
"hf_repo": "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
|
||||
"variants": ["instruct"]
|
||||
},
|
||||
"codellama": {
|
||||
@@ -26,6 +29,7 @@
|
||||
"description": "Meta's code model",
|
||||
"priority": 3,
|
||||
"max_context": 16384,
|
||||
"hf_repo": "codellama/CodeLlama",
|
||||
"variants": ["7b", "13b"]
|
||||
},
|
||||
"llama-3.2": {
|
||||
@@ -33,6 +37,7 @@
|
||||
"description": "Meta's latest general-purpose model with strong coding abilities",
|
||||
"priority": 4,
|
||||
"max_context": 128000,
|
||||
"hf_repo": "meta-llama/Llama-3.2",
|
||||
"variants": ["1b", "3b"]
|
||||
},
|
||||
"phi-4": {
|
||||
@@ -40,6 +45,7 @@
|
||||
"description": "Microsoft's efficient small model with excellent coding performance",
|
||||
"priority": 5,
|
||||
"max_context": 16384,
|
||||
"hf_repo": "microsoft/Phi-4",
|
||||
"variants": ["4b"]
|
||||
},
|
||||
"gemma-2": {
|
||||
@@ -47,6 +53,7 @@
|
||||
"description": "Google's open model, good for coding tasks",
|
||||
"priority": 6,
|
||||
"max_context": 8192,
|
||||
"hf_repo": "google/gemma-2",
|
||||
"variants": ["2b", "4b", "9b"]
|
||||
},
|
||||
"starcoder2": {
|
||||
@@ -54,6 +61,7 @@
|
||||
"description": "BigCode's open code generation model",
|
||||
"priority": 7,
|
||||
"max_context": 8192,
|
||||
"hf_repo": "bigcode/starcoder2",
|
||||
"variants": ["3b", "7b", "15b"]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,18 +1,21 @@
|
||||
Use tools to execute commands and fetch information. Output only tool calls.
|
||||
You have access to tools when needed. Use them ONLY when necessary.
|
||||
|
||||
Available tools:
|
||||
- bash: Execute shell commands
|
||||
- webfetch: Fetch web content (supports text/markdown/html formats)
|
||||
- read: Read files
|
||||
- write: Create files
|
||||
- bash: Execute shell commands (only when needed)
|
||||
- webfetch: Fetch web content (only for current info)
|
||||
- read: Read files (only when reading files)
|
||||
- write: Create files (only when creating files)
|
||||
|
||||
IMPORTANT: When requesting webfetch, ALWAYS provide a URL that actually exists. Do not hallucinate or guess URLs. If a URL returns 404 or errors, stop trying.
|
||||
IMPORTANT:
|
||||
- Answer from your knowledge FIRST. Only use tools when required.
|
||||
- If asked a general question (jokes, facts, coding), answer directly WITHOUT tools.
|
||||
- Use webfetch ONLY for real-time info (news, weather, current events).
|
||||
- Use bash ONLY for file operations or system commands.
|
||||
- After using a tool, provide a final answer based on the result.
|
||||
- NO explanations. NO numbered lists. NO markdown code blocks.
|
||||
|
||||
Format:
|
||||
Format when using tools:
|
||||
TOOL: bash
|
||||
ARGUMENTS: {"command": "your command here"}
|
||||
|
||||
TOOL: webfetch
|
||||
ARGUMENTS: {"url": "https://example.com", "format": "text"}
|
||||
|
||||
No explanations. No numbered lists. No markdown. Only output tool calls.
|
||||
Answer directly when possible. Be helpful and concise.
|
||||
|
||||
+125
-47
@@ -96,7 +96,8 @@ def _create_response(
|
||||
tool_calls: list,
|
||||
finish_reason: str,
|
||||
prompt: str,
|
||||
request: ChatCompletionRequest
|
||||
request: ChatCompletionRequest,
|
||||
swarm_manager=None
|
||||
) -> ChatCompletionResponse:
|
||||
"""Create a chat completion response.
|
||||
|
||||
@@ -106,26 +107,49 @@ def _create_response(
|
||||
finish_reason: Finish reason
|
||||
prompt: Original prompt for token counting
|
||||
request: Original request
|
||||
swarm_manager: Swarm manager instance (optional, for getting model name)
|
||||
|
||||
Returns:
|
||||
ChatCompletionResponse
|
||||
"""
|
||||
# Ensure content is at least an empty string (never None for OpenAI compatibility)
|
||||
if content is None:
|
||||
content = ""
|
||||
|
||||
prompt_tokens = count_tokens(prompt)
|
||||
completion_tokens = count_tokens(content)
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
# Get actual model name from swarm manager
|
||||
model_name = request.model
|
||||
system_fingerprint = None
|
||||
if swarm_manager:
|
||||
status = swarm_manager.get_status()
|
||||
model_name = status.model_name
|
||||
# Sanitize system_fingerprint to only include safe characters
|
||||
import re
|
||||
raw_fingerprint = model_name.lower().replace(" ", "-")
|
||||
system_fingerprint = re.sub(r'[^a-z0-9\-_]', '', raw_fingerprint)
|
||||
|
||||
# Build message - omit tool_calls entirely if empty (OpenAI behavior)
|
||||
message_kwargs = {
|
||||
"role": "assistant",
|
||||
"content": content
|
||||
}
|
||||
if tool_calls:
|
||||
message_kwargs["tool_calls"] = tool_calls
|
||||
|
||||
message = ChatMessage(**message_kwargs)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=f"chatcmpl-{uuid.uuid4().hex[:12]}",
|
||||
created=int(time.time()),
|
||||
model=request.model,
|
||||
model=model_name,
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatMessage(
|
||||
role="assistant",
|
||||
content=content,
|
||||
tool_calls=tool_calls
|
||||
),
|
||||
message=message,
|
||||
logprobs=None,
|
||||
finish_reason=finish_reason
|
||||
)
|
||||
],
|
||||
@@ -133,7 +157,9 @@ def _create_response(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens
|
||||
)
|
||||
),
|
||||
stats={},
|
||||
system_fingerprint=system_fingerprint
|
||||
)
|
||||
|
||||
|
||||
@@ -141,7 +167,8 @@ async def _generate_with_local_swarm(
|
||||
swarm_manager,
|
||||
prompt: str,
|
||||
max_tokens: int,
|
||||
temperature: float
|
||||
temperature: float,
|
||||
stream: bool = False
|
||||
) -> tuple[str, int, float]:
|
||||
"""Generate response using local swarm.
|
||||
|
||||
@@ -150,10 +177,12 @@ async def _generate_with_local_swarm(
|
||||
prompt: Prompt to generate from
|
||||
max_tokens: Maximum tokens to generate
|
||||
temperature: Sampling temperature
|
||||
stream: Whether this is a streaming request
|
||||
|
||||
Returns:
|
||||
Tuple of (response_text, tokens_generated, tokens_per_second)
|
||||
"""
|
||||
try:
|
||||
result = await swarm_manager.generate(
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
@@ -167,6 +196,9 @@ async def _generate_with_local_swarm(
|
||||
response.tokens_generated,
|
||||
response.tokens_per_second
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error in swarm generation")
|
||||
raise
|
||||
|
||||
|
||||
async def _generate_with_federation(
|
||||
@@ -193,14 +225,14 @@ async def _generate_with_federation(
|
||||
min_peers=0
|
||||
)
|
||||
|
||||
content = result.final_response
|
||||
content = result.final_response or ""
|
||||
|
||||
# Check for tool calls
|
||||
content_parsed, tool_calls_parsed = parse_tool_calls(content)
|
||||
if tool_calls_parsed:
|
||||
return content_parsed, tool_calls_parsed, "tool_calls"
|
||||
return content_parsed or "", tool_calls_parsed, "tool_calls"
|
||||
|
||||
return content, [], "stop"
|
||||
return content or "", [], "stop"
|
||||
|
||||
|
||||
async def handle_chat_completion(
|
||||
@@ -231,10 +263,12 @@ async def handle_chat_completion(
|
||||
prompt = format_messages_with_tools(request.messages, None)
|
||||
has_tools = request.tools is not None and len(request.tools) > 0
|
||||
|
||||
logger.debug(f"\n{'='*60}")
|
||||
logger.debug(f"REQUEST: has_tools={has_tools}, stream={request.stream}")
|
||||
logger.debug(f"MODE: {'opencode' if use_opencode_tools else 'local'} tools")
|
||||
logger.debug(f"{'='*60}")
|
||||
logger.info(f"\n{'='*60}")
|
||||
logger.info(f"CHAT COMPLETION REQUEST:")
|
||||
logger.info(f" has_tools={has_tools}, stream={request.stream}")
|
||||
logger.info(f" use_opencode={use_opencode_tools}")
|
||||
logger.info(f" messages={len(request.messages)}")
|
||||
logger.info(f"{'='*60}")
|
||||
|
||||
# Use federation if available
|
||||
if federated_swarm is not None:
|
||||
@@ -244,44 +278,88 @@ async def handle_chat_completion(
|
||||
content, tool_calls, finish_reason = await _generate_with_federation(
|
||||
federated_swarm, prompt, request.max_tokens or 1024, request.temperature or 0.7
|
||||
)
|
||||
return _create_response(content, tool_calls, finish_reason, prompt, request)
|
||||
return _create_response(content, tool_calls, finish_reason, prompt, request, swarm_manager)
|
||||
|
||||
# Use local swarm
|
||||
logger.debug("Using local swarm generation")
|
||||
|
||||
|
||||
# Build conversation history
|
||||
messages = list(request.messages)
|
||||
|
||||
# Initialize iteration counter and response text
|
||||
iteration = 0
|
||||
max_iterations = 3
|
||||
response_text = ""
|
||||
|
||||
while iteration < max_iterations:
|
||||
iteration += 1
|
||||
logger.info(f"--- Tool Execution Iteration {iteration} ---")
|
||||
|
||||
# Generate response
|
||||
logger.debug(f"Generating response...")
|
||||
response_text, tokens_generated, tps = await _generate_with_local_swarm(
|
||||
swarm_manager, prompt, request.max_tokens or 1024, request.temperature or 0.7
|
||||
)
|
||||
|
||||
logger.debug(f"DEBUG: Generated response (tokens={tokens_generated}, t/s={tps:.1f})")
|
||||
logger.debug(f"DEBUG: Response preview: {response_text[:200]}...")
|
||||
logger.info(f"Generated response ({len(response_text)} chars, {tokens_generated} tokens)")
|
||||
logger.debug(f"Response: {response_text[:200]}...")
|
||||
|
||||
# Parse tool calls if tools were provided
|
||||
content = response_text
|
||||
tool_calls = []
|
||||
finish_reason = "stop"
|
||||
# Check for tool calls
|
||||
parsed_content, tool_calls_parsed = parse_tool_calls(response_text)
|
||||
|
||||
if has_tools:
|
||||
logger.debug(f"DEBUG: Parsing tool calls from response...")
|
||||
content, tool_calls_parsed = parse_tool_calls(response_text)
|
||||
logger.debug(f"DEBUG: parse_tool_calls returned: content_len={len(content)}, parsed={tool_calls_parsed is not None}")
|
||||
if not tool_calls_parsed:
|
||||
# No more tools - this is the final answer
|
||||
logger.info(f"✅ Final answer (no tools) after {iteration} iteration(s)")
|
||||
return _create_response(parsed_content, [], "stop", prompt, request, swarm_manager)
|
||||
|
||||
if tool_calls_parsed:
|
||||
logger.debug(f" 🔧 Model requesting {len(tool_calls_parsed)} tool(s)...")
|
||||
executor = get_tool_executor()
|
||||
if executor:
|
||||
logger.debug(f" 🔗 Tool executor: {executor.tool_host_url or 'local'}")
|
||||
else:
|
||||
logger.debug(f" ⚠️ No tool executor configured!")
|
||||
# Tools detected - execute them
|
||||
logger.info(f"🔧 Found {len(tool_calls_parsed)} tool call(s)")
|
||||
for i, tc in enumerate(tool_calls_parsed):
|
||||
tool_name = tc.get("function", {}).get("name", "")
|
||||
args_str = tc.get("function", {}).get("arguments", "{}")
|
||||
logger.info(f" [{i+1}] {tool_name}: {args_str[:100]}...")
|
||||
|
||||
# Execute tools
|
||||
tool_results_str = await _execute_tools(tool_calls_parsed, client_working_dir, executor)
|
||||
content = tool_results_str
|
||||
finish_reason = "stop"
|
||||
tool_calls = [] # Clear tool_calls since we executed them
|
||||
logger.debug(f" ✅ All tools executed, returning results")
|
||||
else:
|
||||
logger.debug(f"DEBUG: No tool calls parsed from response")
|
||||
else:
|
||||
logger.debug(f"DEBUG: No tools requested, returning normal response")
|
||||
# Add assistant message to history
|
||||
messages.append(ChatMessage(role="assistant", content=response_text))
|
||||
|
||||
return _create_response(content, tool_calls, finish_reason, prompt, request)
|
||||
# Execute all tools
|
||||
logger.info(f"⏱️ Executing tools...")
|
||||
tool_results_str = await _execute_tools(tool_calls_parsed, client_working_dir, get_tool_executor())
|
||||
|
||||
# Add tool result to history with STOP instruction
|
||||
# The model needs to be told explicitly to STOP calling tools
|
||||
tool_result_with_instruction = (
|
||||
f"{tool_results_str}\n\n"
|
||||
f"IMPORTANT: You have received the tool result above. "
|
||||
f"DO NOT call any more tools. Provide your final answer now."
|
||||
)
|
||||
messages.append(ChatMessage(role="tool", content=tool_result_with_instruction))
|
||||
logger.info(f"✅ Tools executed ({len(tool_results_str)} chars)")
|
||||
|
||||
# Continue loop - generate response with tool results
|
||||
logger.info(f"🔄 Generating response with tool results...")
|
||||
|
||||
# Format with tool results (but DON'T include tool instruction - model should just use results)
|
||||
next_prompt = format_messages_with_tools(messages, None if use_opencode_tools else request.tools)
|
||||
|
||||
response_text, tokens_generated, tps = await _generate_with_local_swarm(
|
||||
swarm_manager, next_prompt, request.max_tokens or 1024, request.temperature or 0.7
|
||||
)
|
||||
|
||||
logger.info(f"Generated with tool results ({len(response_text)} chars, {tokens_generated} tokens)")
|
||||
logger.debug(f"Response: {response_text[:200]}...")
|
||||
|
||||
# Check for more tools in the new response
|
||||
parsed_content, tool_calls_parsed = parse_tool_calls(response_text)
|
||||
|
||||
if not tool_calls_parsed:
|
||||
# No more tools - final answer
|
||||
logger.info(f"✅ Final answer (after tool execution) after {iteration} iteration(s)")
|
||||
return _create_response(parsed_content, [], "stop", prompt, request, swarm_manager)
|
||||
|
||||
# More tools detected - continue loop
|
||||
logger.info(f"🔧 More tools found - continuing loop")
|
||||
|
||||
# Max iterations reached - force return last response
|
||||
logger.warning(f"⚠️ Max tool iterations ({max_iterations}) reached")
|
||||
logger.warning(f"⚠️ Returning last response (may include incomplete tool call)")
|
||||
return _create_response(response_text, [], "stop", prompt, request, swarm_manager)
|
||||
|
||||
+12
-6
@@ -4,7 +4,7 @@ Pydantic models matching OpenAI's API specification.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Literal, Dict, Any, Union
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
|
||||
|
||||
class FunctionDefinition(BaseModel):
|
||||
@@ -29,14 +29,16 @@ class ToolCall(BaseModel):
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
"""A chat message."""
|
||||
role: Literal["system", "user", "assistant", "tool"] = Field(..., description="Role of the message sender")
|
||||
content: Optional[str] = Field(default=None, description="Message content")
|
||||
tool_calls: Optional[List[ToolCall]] = Field(default_factory=list, description="Tool calls from assistant")
|
||||
role: Literal["system", "user", "assistant", "tool"] = Field(..., description="Role of message sender")
|
||||
content: str = Field(default="", description="Message content")
|
||||
tool_calls: Optional[List[ToolCall]] = Field(default=None, description="Tool calls from assistant")
|
||||
#tool_call_id: Optional[str] = Field(default=None, description="ID of tool call this message is responding to")
|
||||
#name: Optional[str] = Field(default=None, description="Name of the tool/function")
|
||||
|
||||
class Config:
|
||||
model_config = ConfigDict(
|
||||
# Use Pydantic's exclude_none to omit tool_calls when None
|
||||
exclude_none=True
|
||||
)
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
@@ -51,7 +53,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
tools: Optional[List[Tool]] = Field(default=None, description="List of tools the model may call")
|
||||
tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(default="auto", description="How to choose tools")
|
||||
|
||||
class Config:
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"model": "local-swarm",
|
||||
@@ -62,12 +64,14 @@ class ChatCompletionRequest(BaseModel):
|
||||
"temperature": 0.7
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class ChatCompletionChoice(BaseModel):
|
||||
"""A choice in the chat completion response."""
|
||||
index: int = Field(..., description="Choice index")
|
||||
message: ChatMessage = Field(..., description="Generated message")
|
||||
logprobs: Optional[Any] = Field(default=None, description="Log probabilities")
|
||||
finish_reason: Optional[str] = Field(default="stop", description="Reason for finishing (stop, length, tool_calls, etc.)")
|
||||
|
||||
|
||||
@@ -87,6 +91,8 @@ class ChatCompletionResponse(BaseModel):
|
||||
model: str = Field(..., description="Model used")
|
||||
choices: List[ChatCompletionChoice] = Field(..., description="Generated choices")
|
||||
usage: UsageInfo = Field(..., description="Token usage")
|
||||
stats: Dict[str, Any] = Field(default_factory=dict, description="Additional stats")
|
||||
system_fingerprint: Optional[str] = Field(default=None, description="System fingerprint")
|
||||
|
||||
|
||||
class ChatCompletionStreamChoice(BaseModel):
|
||||
|
||||
+76
-1
@@ -224,6 +224,43 @@ def set_federated_swarm(swarm):
|
||||
federated_swarm = swarm
|
||||
|
||||
|
||||
async def _stream_response(response: ChatCompletionResponse):
|
||||
"""Stream a chat completion response as Server-Sent Events.
|
||||
|
||||
For compatibility with OpenAI format, we use delta format for streaming.
|
||||
The response is sent as a single chunk since we don't support
|
||||
true token-by-token streaming yet.
|
||||
"""
|
||||
import json
|
||||
from api.models import ChatCompletionStreamResponse, ChatCompletionStreamChoice
|
||||
|
||||
# Convert to streaming format with delta
|
||||
message = response.choices[0].message
|
||||
choice = ChatCompletionStreamChoice(
|
||||
index=0,
|
||||
delta={"content": message.content},
|
||||
finish_reason="stop"
|
||||
)
|
||||
|
||||
stream_response = ChatCompletionStreamResponse(
|
||||
id=response.id,
|
||||
created=response.created,
|
||||
model=response.model,
|
||||
choices=[choice]
|
||||
)
|
||||
|
||||
# Send as SSE event
|
||||
data = stream_response.model_dump_json(exclude_none=True)
|
||||
logger.debug(f"Streaming SSE data (delta format): {len(data)} chars")
|
||||
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# Send done event
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
logger.debug(f"Streaming complete")
|
||||
|
||||
|
||||
@router.post("/v1/chat/completions")
|
||||
async def chat_completions(request: ChatCompletionRequest, fastapi_request: Request):
|
||||
"""Generate chat completion."""
|
||||
@@ -239,6 +276,10 @@ async def chat_completions(request: ChatCompletionRequest, fastapi_request: Requ
|
||||
client_working_dir = fastapi_request.headers.get("X-Client-Working-Dir")
|
||||
|
||||
try:
|
||||
logger.info(f"📥 Processing chat completion request...")
|
||||
logger.info(f" Stream: {request.stream}")
|
||||
logger.info(f" Model: {request.model}")
|
||||
|
||||
response = await handle_chat_completion(
|
||||
request=request,
|
||||
swarm_manager=swarm_manager,
|
||||
@@ -246,7 +287,41 @@ async def chat_completions(request: ChatCompletionRequest, fastapi_request: Requ
|
||||
client_working_dir=client_working_dir,
|
||||
use_opencode_tools=_USE_OPENCODE_TOOLS
|
||||
)
|
||||
return response
|
||||
|
||||
logger.info(f"✅ Response generated successfully")
|
||||
logger.debug(f"Response object type: {type(response)}")
|
||||
|
||||
# Handle streaming if requested
|
||||
if request.stream:
|
||||
logger.info(f"🌊 Returning streaming response")
|
||||
return StreamingResponse(
|
||||
_stream_response(response),
|
||||
media_type="text/event-stream"
|
||||
)
|
||||
|
||||
# For non-streaming, return JSON with proper handling of None fields:
|
||||
# - tool_calls: omit when None (no tools)
|
||||
# - logprobs: always include as null (even when None)
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
logger.info(f"📋 Returning JSON response")
|
||||
|
||||
# Build response dict with custom handling
|
||||
response_dict = response.model_dump(exclude_none=True)
|
||||
|
||||
# Ensure logprobs is always present (as null if not available)
|
||||
for choice in response_dict.get('choices', []):
|
||||
if 'logprobs' not in choice:
|
||||
choice['logprobs'] = None
|
||||
|
||||
logger.debug(f"Response dict: {response_dict}")
|
||||
|
||||
return JSONResponse(
|
||||
content=response_dict,
|
||||
status_code=200
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error in chat completion")
|
||||
logger.error(f"Error type: {type(e).__name__}")
|
||||
logger.error(f"Error message: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
|
||||
|
||||
@@ -76,7 +76,8 @@ class MainRunner:
|
||||
config = select_optimal_model(
|
||||
self.hardware,
|
||||
preferred_model=self.args.model,
|
||||
force_instances=self.args.instances
|
||||
force_instances=self.args.instances,
|
||||
use_mlx=None # Auto-detect based on hardware
|
||||
)
|
||||
|
||||
if not config:
|
||||
|
||||
+115
-1
@@ -123,6 +123,10 @@ class ModelRegistry:
|
||||
return None
|
||||
|
||||
meta = self._metadata[model_id]
|
||||
# Ensure meta is a dict (not a string like "_comment")
|
||||
if not isinstance(meta, dict):
|
||||
return None
|
||||
|
||||
sizes = self._mlx_sizes if use_mlx else self._gguf_sizes
|
||||
quality_map = self._get_quality_map(use_mlx)
|
||||
|
||||
@@ -154,7 +158,10 @@ class ModelRegistry:
|
||||
def list_models(self, use_mlx: bool = False) -> List[Model]:
|
||||
"""List all available models."""
|
||||
models = []
|
||||
for model_id in self._metadata.keys():
|
||||
for model_id, meta in self._metadata.items():
|
||||
# Skip non-dict entries (like _comment)
|
||||
if not isinstance(meta, dict):
|
||||
continue
|
||||
model = self.get_model(model_id, use_mlx)
|
||||
if model:
|
||||
models.append(model)
|
||||
@@ -192,3 +199,110 @@ def get_model(model_id: str, use_mlx: bool = False) -> Optional[Model]:
|
||||
def list_models(use_mlx: bool = False) -> List[Model]:
|
||||
"""List all available models."""
|
||||
return _registry.list_models(use_mlx)
|
||||
|
||||
|
||||
def get_model_hf_repo(model_id: str, variant: ModelVariant, quant: QuantizationConfig) -> Optional[str]:
|
||||
"""Get HuggingFace repository ID for a GGUF model.
|
||||
|
||||
Args:
|
||||
model_id: Model identifier
|
||||
variant: Model variant (size)
|
||||
quant: Quantization config
|
||||
|
||||
Returns:
|
||||
HuggingFace repo ID (e.g., "Qwen/Qwen2.5-Coder-7B-Instruct-GGUF") or None if unknown
|
||||
"""
|
||||
# Get the base repo from metadata
|
||||
if model_id not in _registry._metadata:
|
||||
return None
|
||||
|
||||
meta = _registry._metadata[model_id]
|
||||
if not isinstance(meta, dict):
|
||||
return None
|
||||
|
||||
base_repo = meta.get("hf_repo")
|
||||
if not base_repo:
|
||||
return None
|
||||
|
||||
# Convert variant size (e.g., "14b" to "-14B") and construct repo ID
|
||||
size_suffix = f"-{variant.size.upper()}"
|
||||
|
||||
# For GGUF, add -Instruct-GGUF suffix
|
||||
repo_id = f"{base_repo}{size_suffix}-Instruct-GGUF"
|
||||
|
||||
return repo_id
|
||||
|
||||
|
||||
def get_model_hf_repo_mlx(model_id: str, variant: ModelVariant, quant: QuantizationConfig) -> Optional[str]:
|
||||
"""Get HuggingFace repository ID for an MLX model.
|
||||
|
||||
Args:
|
||||
model_id: Model identifier
|
||||
variant: Model variant (size)
|
||||
quant: Quantization config
|
||||
|
||||
Returns:
|
||||
HuggingFace repo ID (e.g., "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit") or None if unknown
|
||||
"""
|
||||
# Get the base repo from metadata
|
||||
if model_id not in _registry._metadata:
|
||||
return None
|
||||
|
||||
meta = _registry._metadata[model_id]
|
||||
if not isinstance(meta, dict):
|
||||
return None
|
||||
|
||||
base_repo = meta.get("hf_repo")
|
||||
if not base_repo:
|
||||
return None
|
||||
|
||||
# MLX models are typically in mlx-community namespace
|
||||
# Format: mlx-community/{ModelName}-{Size}-{Quantization}
|
||||
# For example: mlx-community/Qwen2.5-Coder-14B-Instruct-4bit
|
||||
|
||||
# Convert variant size (e.g., "14b" to "-14B")
|
||||
size_suffix = f"-{variant.size.upper()}"
|
||||
|
||||
# Add quantization suffix (e.g., "-4bit" for MLX quantization names)
|
||||
quant_suffix = f"-{quant.name}"
|
||||
|
||||
# Construct the full repo name
|
||||
model_name = base_repo.split('/')[-1] # Get just the model name, not the org
|
||||
repo_id = f"mlx-community/{model_name}{size_suffix}-Instruct{quant_suffix}"
|
||||
|
||||
return repo_id
|
||||
|
||||
|
||||
def get_model_filename(model_id: str, variant: ModelVariant, quant: QuantizationConfig) -> str:
|
||||
"""Get the filename for a GGUF model file.
|
||||
|
||||
Args:
|
||||
model_id: Model identifier
|
||||
variant: Model variant (size)
|
||||
quant: Quantization config
|
||||
|
||||
Returns:
|
||||
GGUF filename (e.g., "qwen2.5-coder-14b-instruct-q4_k_m.gguf")
|
||||
"""
|
||||
# Extract model name from metadata
|
||||
if model_id not in _registry._metadata:
|
||||
meta = {"name": model_id, "hf_repo": model_id}
|
||||
else:
|
||||
meta = _registry._metadata[model_id]
|
||||
|
||||
if not isinstance(meta, dict):
|
||||
meta = {"name": model_id, "hf_repo": model_id}
|
||||
|
||||
# Use the base repo name or model name
|
||||
base_name = meta.get("hf_repo", meta.get("name", model_id))
|
||||
# Remove org prefix if present
|
||||
if '/' in base_name:
|
||||
base_name = base_name.split('/')[-1]
|
||||
|
||||
# Standard GGUF naming (all lowercase): {model}-{variant}-instruct-{quantization}.gguf
|
||||
# For example: qwen2.5-coder-14b-instruct-q4_k_m.gguf
|
||||
variant_size = f"-{variant.size.lower()}"
|
||||
quant_name = quant.name.lower()
|
||||
filename = f"{base_name.lower()}{variant_size}-instruct-{quant_name}.gguf"
|
||||
|
||||
return filename
|
||||
|
||||
@@ -76,11 +76,12 @@ def select_optimal_model(
|
||||
force_instances: Optional[int] = None,
|
||||
context_size: int = 32768,
|
||||
offload_percent: float = 0.0,
|
||||
use_mlx: bool = False
|
||||
use_mlx: Optional[bool] = None
|
||||
) -> Optional[ModelConfig]:
|
||||
"""Select the optimal model configuration for given hardware."""
|
||||
if use_mlx is None and hardware.is_apple_silicon:
|
||||
use_mlx = True
|
||||
# Auto-detect MLX usage for Apple Silicon if not explicitly set
|
||||
if use_mlx is None:
|
||||
use_mlx = hardware.is_apple_silicon
|
||||
|
||||
available_vram, _ = get_available_memory_with_offload(hardware, offload_percent)
|
||||
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
"""Test Apple Silicon MLX auto-detection and download."""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Add src to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
||||
|
||||
def test_apple_silicon_mlx_selection():
|
||||
"""Test that Apple Silicon correctly selects MLX models."""
|
||||
from hardware.detector import HardwareProfile, GPUInfo
|
||||
from models.selector import select_optimal_model
|
||||
|
||||
# Mock Apple Silicon hardware
|
||||
class MockAppleHardware:
|
||||
os = "darwin"
|
||||
cpu_cores = 12
|
||||
ram_gb = 24.0
|
||||
ram_available_gb = 12.0
|
||||
is_apple_silicon = True
|
||||
has_dedicated_gpu = False
|
||||
gpu = GPUInfo(name="Apple Silicon GPU", vram_gb=24.0, driver_version=None)
|
||||
available_memory_gb = 12.0
|
||||
recommended_memory_gb = 12.0
|
||||
|
||||
hardware = MockAppleHardware()
|
||||
|
||||
# Test auto-detection (use_mlx=None)
|
||||
print("=" * 60)
|
||||
print("Apple Silicon MLX Auto-Detection Test")
|
||||
print("=" * 60)
|
||||
|
||||
print("\n1. Testing auto-detection (use_mlx=None)...")
|
||||
config = select_optimal_model(hardware, use_mlx=None)
|
||||
|
||||
assert config is not None, "Should find a model"
|
||||
print(f" ✓ Model selected: {config.model.name}")
|
||||
|
||||
# Verify quantization is MLX format (4bit, 8bit, etc.)
|
||||
print("\n2. Verifying MLX quantization format...")
|
||||
is_mlx_format = 'bit' in config.quantization.name.lower()
|
||||
assert is_mlx_format, f"Quantization should be MLX format (4bit/8bit), got {config.quantization.name}"
|
||||
print(f" ✓ Quantization: {config.quantization.name} (MLX format)")
|
||||
|
||||
# Test repository name generation
|
||||
print("\n3. Testing MLX repository name generation...")
|
||||
from models.registry import get_model_hf_repo_mlx
|
||||
|
||||
mlx_repo = get_model_hf_repo_mlx(config.model.id, config.variant, config.quantization)
|
||||
assert mlx_repo is not None, "MLX repository should be generated"
|
||||
assert "mlx-community" in mlx_repo, "Should use mlx-community namespace"
|
||||
assert "-Instruct-" in mlx_repo, "Should have -Instruct- suffix"
|
||||
assert config.quantization.name in mlx_repo, "Should include quantization"
|
||||
print(f" ✓ Repository: {mlx_repo}")
|
||||
|
||||
# Verify it's NOT using GGUF format
|
||||
print("\n4. Verifying NOT using GGUF format...")
|
||||
has_gguf = 'q4_k_m' in config.quantization.name or 'q5_k_m' in config.quantization.name
|
||||
has_gguf_suffix = '-GGUF' in mlx_repo
|
||||
assert not has_gguf, f"Should not use GGUF quantization names"
|
||||
assert not has_gguf_suffix, f"Should not use GGUF repository suffix"
|
||||
print(f" ✓ Not using GGUF format")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("All Apple Silicon MLX tests passed!")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
def test_nvidia_gpu_gguf_selection():
|
||||
"""Test that NVIDIA GPU correctly selects GGUF models."""
|
||||
from hardware.detector import HardwareProfile, GPUInfo
|
||||
from models.selector import select_optimal_model
|
||||
|
||||
# Mock NVIDIA hardware
|
||||
class MockNvidiaHardware:
|
||||
os = "linux"
|
||||
cpu_cores = 8
|
||||
ram_gb = 32.0
|
||||
ram_available_gb = 20.0
|
||||
is_apple_silicon = False
|
||||
has_dedicated_gpu = True
|
||||
gpu = GPUInfo(name="NVIDIA RTX 4090", vram_gb=24.0, driver_version="550.80")
|
||||
available_memory_gb = 20.0
|
||||
recommended_memory_gb = 20.0
|
||||
|
||||
hardware = MockNvidiaHardware()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("NVIDIA GPU GGUF Auto-Detection Test")
|
||||
print("=" * 60)
|
||||
|
||||
print("\n1. Testing auto-detection (use_mlx=None)...")
|
||||
config = select_optimal_model(hardware, use_mlx=None)
|
||||
|
||||
assert config is not None, "Should find a model"
|
||||
print(f" ✓ Model selected: {config.model.name}")
|
||||
|
||||
# Verify quantization is GGUF format (q4_k_m, q5_k_m, etc.)
|
||||
print("\n2. Verifying GGUF quantization format...")
|
||||
is_gguf_format = 'q' in config.quantization.name.lower()
|
||||
assert is_gguf_format, f"Quantization should be GGUF format (q4_k_m/q5_k_m), got {config.quantization.name}"
|
||||
print(f" ✓ Quantization: {config.quantization.name} (GGUF format)")
|
||||
|
||||
# Test repository name generation
|
||||
print("\n3. Testing GGUF repository name generation...")
|
||||
from models.registry import get_model_hf_repo
|
||||
|
||||
gguf_repo = get_model_hf_repo(config.model.id, config.variant, config.quantization)
|
||||
assert gguf_repo is not None, "GGUF repository should be generated"
|
||||
assert "-GGUF" in gguf_repo, "Should have -GGUF suffix"
|
||||
print(f" ✓ Repository: {gguf_repo}")
|
||||
|
||||
# Verify it's NOT using MLX format
|
||||
print("\n4. Verifying NOT using MLX format...")
|
||||
has_mlx_format = 'bit' in config.quantization.name.lower() and config.quantization.name not in ['q4_k_m', 'q5_k_m', 'q6_k']
|
||||
has_mlx_namespace = 'mlx-community' in gguf_repo
|
||||
assert not has_mlx_namespace, f"Should not use mlx-community namespace"
|
||||
print(f" ✓ Not using MLX format")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("All NVIDIA GPU GGUF tests passed!")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
test_apple_silicon_mlx_selection()
|
||||
test_nvidia_gpu_gguf_selection()
|
||||
print("\n" + "=" * 60)
|
||||
print("ALL AUTO-DETECTION TESTS PASSED!")
|
||||
print("=" * 60)
|
||||
except AssertionError as e:
|
||||
print(f"\n❌ Test failed: {e}")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n❌ Test error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
@@ -0,0 +1,100 @@
|
||||
"""End-to-end test for tool execution with a mock server.
|
||||
|
||||
This tests the complete flow:
|
||||
1. Model generates tool call
|
||||
2. Tools are executed
|
||||
3. Response is generated based on tool results
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_flow():
|
||||
"""Test the tool execution flow end-to-end."""
|
||||
|
||||
# Import after path is set
|
||||
from api.models import ChatMessage, ChatCompletionRequest
|
||||
from api.tool_parser import parse_tool_calls
|
||||
from api.formatting import format_messages_with_tools
|
||||
from tools.executor import ToolExecutor
|
||||
|
||||
print("=" * 60)
|
||||
print("End-to-End Tool Execution Test")
|
||||
print("=" * 60)
|
||||
|
||||
# Test 1: Parse tool call from model response
|
||||
print("\n1. Testing tool parsing...")
|
||||
model_response = "TOOL: bash\nARGUMENTS: {\"command\": \"echo hello\"}"
|
||||
|
||||
content, tool_calls = parse_tool_calls(model_response)
|
||||
assert tool_calls is not None, "Should parse tool call"
|
||||
assert len(tool_calls) == 1, "Should have one tool call"
|
||||
assert tool_calls[0]["function"]["name"] == "bash", "Should be bash tool"
|
||||
print(f" ✓ Parsed tool: {tool_calls[0]['function']['name']}")
|
||||
|
||||
# Test 2: Simulate tool result and format for next prompt
|
||||
print("\n2. Testing tool result formatting...")
|
||||
tool_result = "hello\n"
|
||||
|
||||
# Build conversation history
|
||||
messages = [
|
||||
ChatMessage(role="user", content="Run echo hello"),
|
||||
ChatMessage(role="assistant", content=model_response),
|
||||
ChatMessage(role="tool", content=tool_result)
|
||||
]
|
||||
|
||||
# Format for next generation
|
||||
next_prompt = format_messages_with_tools(messages, None)
|
||||
assert "tool" in next_prompt.lower(), "Prompt should include tool result"
|
||||
assert "hello" in next_prompt, "Prompt should include tool output"
|
||||
print(f" ✓ Tool result formatted for next prompt")
|
||||
|
||||
# Test 3: Verify loop detection
|
||||
print("\n3. Testing loop detection...")
|
||||
seen_tools = set()
|
||||
|
||||
# First tool call
|
||||
tc1 = [{"function": {"name": "bash", "arguments": '{"command": "ls"}'}}]
|
||||
sig1 = "bash:{'command': \"ls\"}'[:50]"
|
||||
seen_tools.add(sig1)
|
||||
print(f" ✓ First tool call tracked")
|
||||
|
||||
# Duplicate tool call
|
||||
tc2 = tc1
|
||||
sig2 = sig1
|
||||
is_duplicate = sig2 in seen_tools
|
||||
assert is_duplicate, "Should detect duplicate"
|
||||
print(f" ✓ Duplicate tool call detected")
|
||||
|
||||
# Test 4: Verify tool result truncation
|
||||
print("\n4. Testing tool result truncation...")
|
||||
long_result = "a" * 3000
|
||||
max_length = 2000
|
||||
|
||||
if len(long_result) > max_length:
|
||||
truncated = long_result[:max_length] + "\n[...truncated...]"
|
||||
assert len(truncated) == max_length + len("\n[...truncated...]"), "Should truncate properly"
|
||||
print(f" ✓ Tool result truncated from {len(long_result)} to {len(truncated)} chars")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("All end-to-end tests passed!")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(test_tool_flow())
|
||||
except AssertionError as e:
|
||||
print(f"\n❌ Test failed: {e}")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n❌ Test error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
@@ -0,0 +1,183 @@
|
||||
"""Integration test for tool execution in chat completions.
|
||||
|
||||
This test verifies that:
|
||||
1. Tools are properly parsed from model output
|
||||
2. Tools are executed and results fed back to model
|
||||
3. The loop continues generating until final response
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
||||
|
||||
from api.models import ChatMessage
|
||||
from api.chat_handlers import handle_chat_completion, _sanitize_tools
|
||||
from api.tool_parser import parse_tool_calls
|
||||
from api.formatting import format_messages_with_tools
|
||||
|
||||
|
||||
class MockSwarm:
|
||||
"""Mock swarm manager for testing."""
|
||||
|
||||
async def generate(self, prompt, max_tokens, temperature, use_consensus):
|
||||
"""Generate a mock response."""
|
||||
# Return different responses based on prompt content
|
||||
if "tool_result" in prompt.lower():
|
||||
# Final response after tool execution
|
||||
return MockResponse("Here's the result: The tool was executed successfully!")
|
||||
else:
|
||||
# First response with tool call
|
||||
return MockResponse("TOOL: bash\nARGUMENTS: {\"command\": \"echo test\"}")
|
||||
|
||||
|
||||
class MockResponse:
|
||||
"""Mock generation result."""
|
||||
|
||||
def __init__(self, text):
|
||||
self.selected_response = MockSelectedResponse(text)
|
||||
|
||||
|
||||
class MockSelectedResponse:
|
||||
"""Mock selected response."""
|
||||
|
||||
def __init__(self, text):
|
||||
self.text = text
|
||||
self.tokens_generated = 50
|
||||
self.tokens_per_second = 10.0
|
||||
|
||||
|
||||
class MockExecutor:
|
||||
"""Mock tool executor."""
|
||||
|
||||
async def execute_tool(self, tool_name, tool_args, working_dir=None):
|
||||
"""Execute a tool mock."""
|
||||
return f"Mock result from {tool_name} with args {tool_args}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_execution_loop():
|
||||
"""Test that tools are executed and loop continues."""
|
||||
print("Testing tool execution loop...")
|
||||
|
||||
# Create a mock request
|
||||
request = ChatMessage(
|
||||
role="user",
|
||||
content="Run echo test"
|
||||
)
|
||||
|
||||
# Wrap in request object
|
||||
from api.models import ChatCompletionRequest
|
||||
req = ChatCompletionRequest(
|
||||
model="test-model",
|
||||
messages=[request],
|
||||
tools=None,
|
||||
max_tokens=1024,
|
||||
temperature=0.7
|
||||
)
|
||||
|
||||
# Create mock swarm
|
||||
swarm = MockSwarm()
|
||||
|
||||
# We can't easily test the full handler without a real tool executor,
|
||||
# so let's test the key parts
|
||||
|
||||
# Test 1: Verify tool parsing works
|
||||
print(" Test 1: Tool parsing")
|
||||
tool_text = 'TOOL: bash\nARGUMENTS: {"command": "echo test"}'
|
||||
content, tool_calls = parse_tool_calls(tool_text)
|
||||
|
||||
assert tool_calls is not None, "Tool calls should be parsed"
|
||||
assert len(tool_calls) == 1, "Should parse one tool call"
|
||||
assert tool_calls[0]["function"]["name"] == "bash", "Tool name should be bash"
|
||||
assert "echo test" in tool_calls[0]["function"]["arguments"], "Command should be in arguments"
|
||||
print(" ✓ Tool parsing works correctly")
|
||||
|
||||
# Test 2: Verify tool instructions are loaded
|
||||
print(" Test 2: Tool instructions")
|
||||
instructions = format_messages_with_tools([request], None)
|
||||
assert len(instructions) > 0, "Instructions should be generated"
|
||||
assert "tool" in instructions.lower(), "Instructions should mention tools"
|
||||
print(" ✓ Tool instructions are loaded")
|
||||
|
||||
# Test 3: Verify multiple tool calls can be parsed
|
||||
print(" Test 3: Multiple tool calls")
|
||||
multi_tool = '''TOOL: bash
|
||||
ARGUMENTS: {"command": "ls"}
|
||||
|
||||
TOOL: write
|
||||
ARGUMENTS: {"filePath": "test.txt", "content": "hello"}'''
|
||||
content, tool_calls = parse_tool_calls(multi_tool)
|
||||
assert tool_calls is not None, "Multiple tools should be parsed"
|
||||
assert len(tool_calls) == 2, "Should parse two tool calls"
|
||||
assert tool_calls[0]["function"]["name"] == "bash", "First tool should be bash"
|
||||
assert tool_calls[1]["function"]["name"] == "write", "Second tool should be write"
|
||||
print(" ✓ Multiple tool calls parsed correctly")
|
||||
|
||||
# Test 4: Verify tool sanitization
|
||||
print(" Test 4: Tool sanitization")
|
||||
# Create mock tool with invalid 'description' in properties
|
||||
from api.models import Tool, FunctionDefinition
|
||||
mock_tool = Tool(
|
||||
type="function",
|
||||
function=FunctionDefinition(
|
||||
name="test_tool",
|
||||
description="Test tool",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": "Invalid field",
|
||||
"param1": {"type": "string"}
|
||||
},
|
||||
"required": ["description", "param1"]
|
||||
}
|
||||
)
|
||||
)
|
||||
sanitized = _sanitize_tools([mock_tool])
|
||||
assert len(sanitized) == 1, "Should return one tool"
|
||||
assert "description" not in sanitized[0].function.parameters.get("properties", {}), \
|
||||
"Should remove invalid 'description' from properties"
|
||||
print(" ✓ Tool sanitization removes invalid fields")
|
||||
|
||||
print("\n✅ All tool execution loop tests passed!")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_tool_parsing():
|
||||
"""Test that normal responses without tools work."""
|
||||
print("\nTesting response without tools...")
|
||||
|
||||
# Test normal response
|
||||
normal_text = "This is a normal response without any tool calls."
|
||||
content, tool_calls = parse_tool_calls(normal_text)
|
||||
|
||||
assert tool_calls is None, "No tool calls should be found"
|
||||
assert content == normal_text, "Content should be returned unchanged"
|
||||
print(" ✓ Normal responses pass through without modification")
|
||||
|
||||
print("\n✅ No-tool parsing test passed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
async def run_tests():
|
||||
try:
|
||||
await test_tool_execution_loop()
|
||||
await test_no_tool_parsing()
|
||||
print("\n" + "=" * 60)
|
||||
print("All integration tests passed!")
|
||||
print("=" * 60)
|
||||
except AssertionError as e:
|
||||
print(f"\n❌ Test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n❌ Test error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
asyncio.run(run_tests())
|
||||
@@ -0,0 +1,105 @@
|
||||
"""Test to verify tool execution is triggered when model generates tool calls."""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_execution_triggered():
|
||||
"""Verify that tool execution is properly triggered."""
|
||||
|
||||
from api.models import ChatMessage, ChatCompletionRequest
|
||||
from api.chat_handlers import handle_chat_completion
|
||||
from api.tool_parser import parse_tool_calls
|
||||
from tools.executor import ToolExecutor, set_tool_executor
|
||||
|
||||
print("=" * 60)
|
||||
print("Tool Execution Trigger Test")
|
||||
print("=" * 60)
|
||||
|
||||
# Create a mock swarm that generates a tool call
|
||||
class MockSwarm:
|
||||
async def generate(self, prompt, max_tokens, temperature, use_consensus):
|
||||
# First call: generate tool call
|
||||
if "user" in prompt and "echo hello" in prompt:
|
||||
return MockResult("TOOL: bash\nARGUMENTS: {\"command\": \"echo hello\"}")
|
||||
# Second call: after tool result, generate answer
|
||||
elif "tool" in prompt.lower():
|
||||
return MockResult("Output: hello\nThe command executed successfully!")
|
||||
else:
|
||||
return MockResult("I don't understand")
|
||||
|
||||
class MockResult:
|
||||
def __init__(self, text):
|
||||
self.selected_response = MockSelectedResponse(text)
|
||||
|
||||
class MockSelectedResponse:
|
||||
def __init__(self, text):
|
||||
self.text = text
|
||||
self.tokens_generated = 20
|
||||
self.tokens_per_second = 5.0
|
||||
|
||||
# Set up tool executor
|
||||
executor = ToolExecutor(tool_host_url=None)
|
||||
set_tool_executor(executor)
|
||||
|
||||
# Create request
|
||||
request = ChatCompletionRequest(
|
||||
model="test-model",
|
||||
messages=[ChatMessage(role="user", content="echo hello")],
|
||||
tools=None, # No explicit tools - should still parse from response
|
||||
max_tokens=1024,
|
||||
temperature=0.7
|
||||
)
|
||||
|
||||
print("\n1. Testing that tool calls are parsed...")
|
||||
model_response = "TOOL: bash\nARGUMENTS: {\"command\": \"echo hello\"}"
|
||||
content, tool_calls = parse_tool_calls(model_response)
|
||||
|
||||
assert tool_calls is not None, "Tool calls should be parsed from response"
|
||||
assert len(tool_calls) == 1, "Should have one tool call"
|
||||
print(f" ✓ Tool call parsed: {tool_calls[0]['function']['name']}")
|
||||
|
||||
print("\n2. Verifying tool executor is set...")
|
||||
from tools.executor import get_tool_executor
|
||||
current_executor = get_tool_executor()
|
||||
assert current_executor is not None, "Tool executor should be set"
|
||||
print(f" ✓ Tool executor configured: {current_executor.tool_host_url or 'local'}")
|
||||
|
||||
print("\n3. Testing tool execution...")
|
||||
# Try to execute the tool
|
||||
try:
|
||||
from api.routes import execute_tool_server_side
|
||||
result = await execute_tool_server_side(
|
||||
"bash",
|
||||
{"command": "echo hello"},
|
||||
working_dir=None
|
||||
)
|
||||
print(f" ✓ Tool executed successfully")
|
||||
print(f" ✓ Result: {result[:50]}..." if len(result) > 50 else f" ✓ Result: {result}")
|
||||
except Exception as e:
|
||||
print(f" ✗ Tool execution failed: {e}")
|
||||
raise
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("All tool execution trigger tests passed!")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(test_tool_execution_triggered())
|
||||
except AssertionError as e:
|
||||
print(f"\n❌ Test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n❌ Test error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
Reference in New Issue
Block a user