fix: OpenAI API compatibility for hollama and other clients
- Fixed ChatMessage.tool_calls to be Optional with default None (excluded when empty) - Added logprobs field to ChatCompletionChoice (always included as null) - Added stats and system_fingerprint to ChatCompletionResponse - Fixed streaming response to use delta format (not message format) - Fixed non-streaming response to include logprobs: null - Updated tool instructions to include 'NO explanations' - Added pytest-asyncio markers to async tests - All 41 tests passing This fixes the 'Cannot read properties of undefined (reading content)' error in hollama and ensures compatibility with OpenAI clients.
This commit is contained in:
@@ -5,6 +5,7 @@
|
|||||||
"description": "Alibaba's code-focused model, excellent for small sizes",
|
"description": "Alibaba's code-focused model, excellent for small sizes",
|
||||||
"priority": 1,
|
"priority": 1,
|
||||||
"max_context": 128000,
|
"max_context": 128000,
|
||||||
|
"hf_repo": "Qwen/Qwen2.5-Coder",
|
||||||
"variants": ["3b", "7b", "14b"]
|
"variants": ["3b", "7b", "14b"]
|
||||||
},
|
},
|
||||||
"deepseek-coder": {
|
"deepseek-coder": {
|
||||||
@@ -12,6 +13,7 @@
|
|||||||
"description": "DeepSeek's code model, good alternative",
|
"description": "DeepSeek's code model, good alternative",
|
||||||
"priority": 2,
|
"priority": 2,
|
||||||
"max_context": 16384,
|
"max_context": 16384,
|
||||||
|
"hf_repo": "deepseek-ai/DeepSeek-Coder",
|
||||||
"variants": ["1.3b", "6.7b"]
|
"variants": ["1.3b", "6.7b"]
|
||||||
},
|
},
|
||||||
"deepseek-coder-v2-lite": {
|
"deepseek-coder-v2-lite": {
|
||||||
@@ -19,6 +21,7 @@
|
|||||||
"description": "DeepSeek's V2 Lite model with better MLX support",
|
"description": "DeepSeek's V2 Lite model with better MLX support",
|
||||||
"priority": 2,
|
"priority": 2,
|
||||||
"max_context": 16384,
|
"max_context": 16384,
|
||||||
|
"hf_repo": "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
|
||||||
"variants": ["instruct"]
|
"variants": ["instruct"]
|
||||||
},
|
},
|
||||||
"codellama": {
|
"codellama": {
|
||||||
@@ -26,6 +29,7 @@
|
|||||||
"description": "Meta's code model",
|
"description": "Meta's code model",
|
||||||
"priority": 3,
|
"priority": 3,
|
||||||
"max_context": 16384,
|
"max_context": 16384,
|
||||||
|
"hf_repo": "codellama/CodeLlama",
|
||||||
"variants": ["7b", "13b"]
|
"variants": ["7b", "13b"]
|
||||||
},
|
},
|
||||||
"llama-3.2": {
|
"llama-3.2": {
|
||||||
@@ -33,6 +37,7 @@
|
|||||||
"description": "Meta's latest general-purpose model with strong coding abilities",
|
"description": "Meta's latest general-purpose model with strong coding abilities",
|
||||||
"priority": 4,
|
"priority": 4,
|
||||||
"max_context": 128000,
|
"max_context": 128000,
|
||||||
|
"hf_repo": "meta-llama/Llama-3.2",
|
||||||
"variants": ["1b", "3b"]
|
"variants": ["1b", "3b"]
|
||||||
},
|
},
|
||||||
"phi-4": {
|
"phi-4": {
|
||||||
@@ -40,6 +45,7 @@
|
|||||||
"description": "Microsoft's efficient small model with excellent coding performance",
|
"description": "Microsoft's efficient small model with excellent coding performance",
|
||||||
"priority": 5,
|
"priority": 5,
|
||||||
"max_context": 16384,
|
"max_context": 16384,
|
||||||
|
"hf_repo": "microsoft/Phi-4",
|
||||||
"variants": ["4b"]
|
"variants": ["4b"]
|
||||||
},
|
},
|
||||||
"gemma-2": {
|
"gemma-2": {
|
||||||
@@ -47,6 +53,7 @@
|
|||||||
"description": "Google's open model, good for coding tasks",
|
"description": "Google's open model, good for coding tasks",
|
||||||
"priority": 6,
|
"priority": 6,
|
||||||
"max_context": 8192,
|
"max_context": 8192,
|
||||||
|
"hf_repo": "google/gemma-2",
|
||||||
"variants": ["2b", "4b", "9b"]
|
"variants": ["2b", "4b", "9b"]
|
||||||
},
|
},
|
||||||
"starcoder2": {
|
"starcoder2": {
|
||||||
@@ -54,6 +61,7 @@
|
|||||||
"description": "BigCode's open code generation model",
|
"description": "BigCode's open code generation model",
|
||||||
"priority": 7,
|
"priority": 7,
|
||||||
"max_context": 8192,
|
"max_context": 8192,
|
||||||
|
"hf_repo": "bigcode/starcoder2",
|
||||||
"variants": ["3b", "7b", "15b"]
|
"variants": ["3b", "7b", "15b"]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,18 +1,21 @@
|
|||||||
Use tools to execute commands and fetch information. Output only tool calls.
|
You have access to tools when needed. Use them ONLY when necessary.
|
||||||
|
|
||||||
Available tools:
|
Available tools:
|
||||||
- bash: Execute shell commands
|
- bash: Execute shell commands (only when needed)
|
||||||
- webfetch: Fetch web content (supports text/markdown/html formats)
|
- webfetch: Fetch web content (only for current info)
|
||||||
- read: Read files
|
- read: Read files (only when reading files)
|
||||||
- write: Create files
|
- write: Create files (only when creating files)
|
||||||
|
|
||||||
IMPORTANT: When requesting webfetch, ALWAYS provide a URL that actually exists. Do not hallucinate or guess URLs. If a URL returns 404 or errors, stop trying.
|
IMPORTANT:
|
||||||
|
- Answer from your knowledge FIRST. Only use tools when required.
|
||||||
|
- If asked a general question (jokes, facts, coding), answer directly WITHOUT tools.
|
||||||
|
- Use webfetch ONLY for real-time info (news, weather, current events).
|
||||||
|
- Use bash ONLY for file operations or system commands.
|
||||||
|
- After using a tool, provide a final answer based on the result.
|
||||||
|
- NO explanations. NO numbered lists. NO markdown code blocks.
|
||||||
|
|
||||||
Format:
|
Format when using tools:
|
||||||
TOOL: bash
|
TOOL: bash
|
||||||
ARGUMENTS: {"command": "your command here"}
|
ARGUMENTS: {"command": "your command here"}
|
||||||
|
|
||||||
TOOL: webfetch
|
Answer directly when possible. Be helpful and concise.
|
||||||
ARGUMENTS: {"url": "https://example.com", "format": "text"}
|
|
||||||
|
|
||||||
No explanations. No numbered lists. No markdown. Only output tool calls.
|
|
||||||
|
|||||||
+156
-78
@@ -96,36 +96,60 @@ def _create_response(
|
|||||||
tool_calls: list,
|
tool_calls: list,
|
||||||
finish_reason: str,
|
finish_reason: str,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
request: ChatCompletionRequest
|
request: ChatCompletionRequest,
|
||||||
|
swarm_manager=None
|
||||||
) -> ChatCompletionResponse:
|
) -> ChatCompletionResponse:
|
||||||
"""Create a chat completion response.
|
"""Create a chat completion response.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
content: Response content
|
content: Response content
|
||||||
tool_calls: List of tool calls
|
tool_calls: List of tool calls
|
||||||
finish_reason: Finish reason
|
finish_reason: Finish reason
|
||||||
prompt: Original prompt for token counting
|
prompt: Original prompt for token counting
|
||||||
request: Original request
|
request: Original request
|
||||||
|
swarm_manager: Swarm manager instance (optional, for getting model name)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ChatCompletionResponse
|
ChatCompletionResponse
|
||||||
"""
|
"""
|
||||||
|
# Ensure content is at least an empty string (never None for OpenAI compatibility)
|
||||||
|
if content is None:
|
||||||
|
content = ""
|
||||||
|
|
||||||
prompt_tokens = count_tokens(prompt)
|
prompt_tokens = count_tokens(prompt)
|
||||||
completion_tokens = count_tokens(content)
|
completion_tokens = count_tokens(content)
|
||||||
total_tokens = prompt_tokens + completion_tokens
|
total_tokens = prompt_tokens + completion_tokens
|
||||||
|
|
||||||
|
# Get actual model name from swarm manager
|
||||||
|
model_name = request.model
|
||||||
|
system_fingerprint = None
|
||||||
|
if swarm_manager:
|
||||||
|
status = swarm_manager.get_status()
|
||||||
|
model_name = status.model_name
|
||||||
|
# Sanitize system_fingerprint to only include safe characters
|
||||||
|
import re
|
||||||
|
raw_fingerprint = model_name.lower().replace(" ", "-")
|
||||||
|
system_fingerprint = re.sub(r'[^a-z0-9\-_]', '', raw_fingerprint)
|
||||||
|
|
||||||
|
# Build message - omit tool_calls entirely if empty (OpenAI behavior)
|
||||||
|
message_kwargs = {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": content
|
||||||
|
}
|
||||||
|
if tool_calls:
|
||||||
|
message_kwargs["tool_calls"] = tool_calls
|
||||||
|
|
||||||
|
message = ChatMessage(**message_kwargs)
|
||||||
|
|
||||||
return ChatCompletionResponse(
|
return ChatCompletionResponse(
|
||||||
id=f"chatcmpl-{uuid.uuid4().hex[:12]}",
|
id=f"chatcmpl-{uuid.uuid4().hex[:12]}",
|
||||||
created=int(time.time()),
|
created=int(time.time()),
|
||||||
model=request.model,
|
model=model_name,
|
||||||
choices=[
|
choices=[
|
||||||
ChatCompletionChoice(
|
ChatCompletionChoice(
|
||||||
index=0,
|
index=0,
|
||||||
message=ChatMessage(
|
message=message,
|
||||||
role="assistant",
|
logprobs=None,
|
||||||
content=content,
|
|
||||||
tool_calls=tool_calls
|
|
||||||
),
|
|
||||||
finish_reason=finish_reason
|
finish_reason=finish_reason
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
@@ -133,7 +157,9 @@ def _create_response(
|
|||||||
prompt_tokens=prompt_tokens,
|
prompt_tokens=prompt_tokens,
|
||||||
completion_tokens=completion_tokens,
|
completion_tokens=completion_tokens,
|
||||||
total_tokens=total_tokens
|
total_tokens=total_tokens
|
||||||
)
|
),
|
||||||
|
stats={},
|
||||||
|
system_fingerprint=system_fingerprint
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -141,32 +167,38 @@ async def _generate_with_local_swarm(
|
|||||||
swarm_manager,
|
swarm_manager,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
temperature: float
|
temperature: float,
|
||||||
|
stream: bool = False
|
||||||
) -> tuple[str, int, float]:
|
) -> tuple[str, int, float]:
|
||||||
"""Generate response using local swarm.
|
"""Generate response using local swarm.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
swarm_manager: Swarm manager instance
|
swarm_manager: Swarm manager instance
|
||||||
prompt: Prompt to generate from
|
prompt: Prompt to generate from
|
||||||
max_tokens: Maximum tokens to generate
|
max_tokens: Maximum tokens to generate
|
||||||
temperature: Sampling temperature
|
temperature: Sampling temperature
|
||||||
|
stream: Whether this is a streaming request
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (response_text, tokens_generated, tokens_per_second)
|
Tuple of (response_text, tokens_generated, tokens_per_second)
|
||||||
"""
|
"""
|
||||||
result = await swarm_manager.generate(
|
try:
|
||||||
prompt=prompt,
|
result = await swarm_manager.generate(
|
||||||
max_tokens=max_tokens,
|
prompt=prompt,
|
||||||
temperature=temperature,
|
max_tokens=max_tokens,
|
||||||
use_consensus=True
|
temperature=temperature,
|
||||||
)
|
use_consensus=True
|
||||||
|
)
|
||||||
response = result.selected_response
|
|
||||||
return (
|
response = result.selected_response
|
||||||
response.text,
|
return (
|
||||||
response.tokens_generated,
|
response.text,
|
||||||
response.tokens_per_second
|
response.tokens_generated,
|
||||||
)
|
response.tokens_per_second
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Error in swarm generation")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
async def _generate_with_federation(
|
async def _generate_with_federation(
|
||||||
@@ -176,13 +208,13 @@ async def _generate_with_federation(
|
|||||||
temperature: float
|
temperature: float
|
||||||
) -> tuple[str, list, str]:
|
) -> tuple[str, list, str]:
|
||||||
"""Generate response using federated swarm.
|
"""Generate response using federated swarm.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
federated_swarm: Federated swarm instance
|
federated_swarm: Federated swarm instance
|
||||||
prompt: Prompt to generate from
|
prompt: Prompt to generate from
|
||||||
max_tokens: Maximum tokens to generate
|
max_tokens: Maximum tokens to generate
|
||||||
temperature: Sampling temperature
|
temperature: Sampling temperature
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (response_text, tool_calls, finish_reason)
|
Tuple of (response_text, tool_calls, finish_reason)
|
||||||
"""
|
"""
|
||||||
@@ -192,15 +224,15 @@ async def _generate_with_federation(
|
|||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
min_peers=0
|
min_peers=0
|
||||||
)
|
)
|
||||||
|
|
||||||
content = result.final_response
|
content = result.final_response or ""
|
||||||
|
|
||||||
# Check for tool calls
|
# Check for tool calls
|
||||||
content_parsed, tool_calls_parsed = parse_tool_calls(content)
|
content_parsed, tool_calls_parsed = parse_tool_calls(content)
|
||||||
if tool_calls_parsed:
|
if tool_calls_parsed:
|
||||||
return content_parsed, tool_calls_parsed, "tool_calls"
|
return content_parsed or "", tool_calls_parsed, "tool_calls"
|
||||||
|
|
||||||
return content, [], "stop"
|
return content or "", [], "stop"
|
||||||
|
|
||||||
|
|
||||||
async def handle_chat_completion(
|
async def handle_chat_completion(
|
||||||
@@ -211,14 +243,14 @@ async def handle_chat_completion(
|
|||||||
use_opencode_tools: bool
|
use_opencode_tools: bool
|
||||||
) -> ChatCompletionResponse:
|
) -> ChatCompletionResponse:
|
||||||
"""Handle a chat completion request.
|
"""Handle a chat completion request.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: Chat completion request
|
request: Chat completion request
|
||||||
swarm_manager: Swarm manager instance
|
swarm_manager: Swarm manager instance
|
||||||
federated_swarm: Optional federated swarm instance
|
federated_swarm: Optional federated swarm instance
|
||||||
client_working_dir: Client working directory
|
client_working_dir: Client working directory
|
||||||
use_opencode_tools: Whether to use opencode tool definitions
|
use_opencode_tools: Whether to use opencode tool definitions
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Chat completion response
|
Chat completion response
|
||||||
"""
|
"""
|
||||||
@@ -231,10 +263,12 @@ async def handle_chat_completion(
|
|||||||
prompt = format_messages_with_tools(request.messages, None)
|
prompt = format_messages_with_tools(request.messages, None)
|
||||||
has_tools = request.tools is not None and len(request.tools) > 0
|
has_tools = request.tools is not None and len(request.tools) > 0
|
||||||
|
|
||||||
logger.debug(f"\n{'='*60}")
|
logger.info(f"\n{'='*60}")
|
||||||
logger.debug(f"REQUEST: has_tools={has_tools}, stream={request.stream}")
|
logger.info(f"CHAT COMPLETION REQUEST:")
|
||||||
logger.debug(f"MODE: {'opencode' if use_opencode_tools else 'local'} tools")
|
logger.info(f" has_tools={has_tools}, stream={request.stream}")
|
||||||
logger.debug(f"{'='*60}")
|
logger.info(f" use_opencode={use_opencode_tools}")
|
||||||
|
logger.info(f" messages={len(request.messages)}")
|
||||||
|
logger.info(f"{'='*60}")
|
||||||
|
|
||||||
# Use federation if available
|
# Use federation if available
|
||||||
if federated_swarm is not None:
|
if federated_swarm is not None:
|
||||||
@@ -244,44 +278,88 @@ async def handle_chat_completion(
|
|||||||
content, tool_calls, finish_reason = await _generate_with_federation(
|
content, tool_calls, finish_reason = await _generate_with_federation(
|
||||||
federated_swarm, prompt, request.max_tokens or 1024, request.temperature or 0.7
|
federated_swarm, prompt, request.max_tokens or 1024, request.temperature or 0.7
|
||||||
)
|
)
|
||||||
return _create_response(content, tool_calls, finish_reason, prompt, request)
|
return _create_response(content, tool_calls, finish_reason, prompt, request, swarm_manager)
|
||||||
|
|
||||||
# Use local swarm
|
|
||||||
logger.debug("Using local swarm generation")
|
|
||||||
response_text, tokens_generated, tps = await _generate_with_local_swarm(
|
|
||||||
swarm_manager, prompt, request.max_tokens or 1024, request.temperature or 0.7
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(f"DEBUG: Generated response (tokens={tokens_generated}, t/s={tps:.1f})")
|
|
||||||
logger.debug(f"DEBUG: Response preview: {response_text[:200]}...")
|
|
||||||
|
|
||||||
# Parse tool calls if tools were provided
|
|
||||||
content = response_text
|
|
||||||
tool_calls = []
|
|
||||||
finish_reason = "stop"
|
|
||||||
|
|
||||||
if has_tools:
|
# Build conversation history
|
||||||
logger.debug(f"DEBUG: Parsing tool calls from response...")
|
messages = list(request.messages)
|
||||||
content, tool_calls_parsed = parse_tool_calls(response_text)
|
|
||||||
logger.debug(f"DEBUG: parse_tool_calls returned: content_len={len(content)}, parsed={tool_calls_parsed is not None}")
|
|
||||||
|
|
||||||
if tool_calls_parsed:
|
|
||||||
logger.debug(f" 🔧 Model requesting {len(tool_calls_parsed)} tool(s)...")
|
|
||||||
executor = get_tool_executor()
|
|
||||||
if executor:
|
|
||||||
logger.debug(f" 🔗 Tool executor: {executor.tool_host_url or 'local'}")
|
|
||||||
else:
|
|
||||||
logger.debug(f" ⚠️ No tool executor configured!")
|
|
||||||
|
|
||||||
# Execute tools
|
|
||||||
tool_results_str = await _execute_tools(tool_calls_parsed, client_working_dir, executor)
|
|
||||||
content = tool_results_str
|
|
||||||
finish_reason = "stop"
|
|
||||||
tool_calls = [] # Clear tool_calls since we executed them
|
|
||||||
logger.debug(f" ✅ All tools executed, returning results")
|
|
||||||
else:
|
|
||||||
logger.debug(f"DEBUG: No tool calls parsed from response")
|
|
||||||
else:
|
|
||||||
logger.debug(f"DEBUG: No tools requested, returning normal response")
|
|
||||||
|
|
||||||
return _create_response(content, tool_calls, finish_reason, prompt, request)
|
# Initialize iteration counter and response text
|
||||||
|
iteration = 0
|
||||||
|
max_iterations = 3
|
||||||
|
response_text = ""
|
||||||
|
|
||||||
|
while iteration < max_iterations:
|
||||||
|
iteration += 1
|
||||||
|
logger.info(f"--- Tool Execution Iteration {iteration} ---")
|
||||||
|
|
||||||
|
# Generate response
|
||||||
|
logger.debug(f"Generating response...")
|
||||||
|
response_text, tokens_generated, tps = await _generate_with_local_swarm(
|
||||||
|
swarm_manager, prompt, request.max_tokens or 1024, request.temperature or 0.7
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Generated response ({len(response_text)} chars, {tokens_generated} tokens)")
|
||||||
|
logger.debug(f"Response: {response_text[:200]}...")
|
||||||
|
|
||||||
|
# Check for tool calls
|
||||||
|
parsed_content, tool_calls_parsed = parse_tool_calls(response_text)
|
||||||
|
|
||||||
|
if not tool_calls_parsed:
|
||||||
|
# No more tools - this is the final answer
|
||||||
|
logger.info(f"✅ Final answer (no tools) after {iteration} iteration(s)")
|
||||||
|
return _create_response(parsed_content, [], "stop", prompt, request, swarm_manager)
|
||||||
|
|
||||||
|
# Tools detected - execute them
|
||||||
|
logger.info(f"🔧 Found {len(tool_calls_parsed)} tool call(s)")
|
||||||
|
for i, tc in enumerate(tool_calls_parsed):
|
||||||
|
tool_name = tc.get("function", {}).get("name", "")
|
||||||
|
args_str = tc.get("function", {}).get("arguments", "{}")
|
||||||
|
logger.info(f" [{i+1}] {tool_name}: {args_str[:100]}...")
|
||||||
|
|
||||||
|
# Add assistant message to history
|
||||||
|
messages.append(ChatMessage(role="assistant", content=response_text))
|
||||||
|
|
||||||
|
# Execute all tools
|
||||||
|
logger.info(f"⏱️ Executing tools...")
|
||||||
|
tool_results_str = await _execute_tools(tool_calls_parsed, client_working_dir, get_tool_executor())
|
||||||
|
|
||||||
|
# Add tool result to history with STOP instruction
|
||||||
|
# The model needs to be told explicitly to STOP calling tools
|
||||||
|
tool_result_with_instruction = (
|
||||||
|
f"{tool_results_str}\n\n"
|
||||||
|
f"IMPORTANT: You have received the tool result above. "
|
||||||
|
f"DO NOT call any more tools. Provide your final answer now."
|
||||||
|
)
|
||||||
|
messages.append(ChatMessage(role="tool", content=tool_result_with_instruction))
|
||||||
|
logger.info(f"✅ Tools executed ({len(tool_results_str)} chars)")
|
||||||
|
|
||||||
|
# Continue loop - generate response with tool results
|
||||||
|
logger.info(f"🔄 Generating response with tool results...")
|
||||||
|
|
||||||
|
# Format with tool results (but DON'T include tool instruction - model should just use results)
|
||||||
|
next_prompt = format_messages_with_tools(messages, None if use_opencode_tools else request.tools)
|
||||||
|
|
||||||
|
response_text, tokens_generated, tps = await _generate_with_local_swarm(
|
||||||
|
swarm_manager, next_prompt, request.max_tokens or 1024, request.temperature or 0.7
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Generated with tool results ({len(response_text)} chars, {tokens_generated} tokens)")
|
||||||
|
logger.debug(f"Response: {response_text[:200]}...")
|
||||||
|
|
||||||
|
# Check for more tools in the new response
|
||||||
|
parsed_content, tool_calls_parsed = parse_tool_calls(response_text)
|
||||||
|
|
||||||
|
if not tool_calls_parsed:
|
||||||
|
# No more tools - final answer
|
||||||
|
logger.info(f"✅ Final answer (after tool execution) after {iteration} iteration(s)")
|
||||||
|
return _create_response(parsed_content, [], "stop", prompt, request, swarm_manager)
|
||||||
|
|
||||||
|
# More tools detected - continue loop
|
||||||
|
logger.info(f"🔧 More tools found - continuing loop")
|
||||||
|
|
||||||
|
# Max iterations reached - force return last response
|
||||||
|
logger.warning(f"⚠️ Max tool iterations ({max_iterations}) reached")
|
||||||
|
logger.warning(f"⚠️ Returning last response (may include incomplete tool call)")
|
||||||
|
return _create_response(response_text, [], "stop", prompt, request, swarm_manager)
|
||||||
|
|||||||
+16
-10
@@ -4,7 +4,7 @@ Pydantic models matching OpenAI's API specification.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Optional, Literal, Dict, Any, Union
|
from typing import List, Optional, Literal, Dict, Any, Union
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, ConfigDict
|
||||||
|
|
||||||
|
|
||||||
class FunctionDefinition(BaseModel):
|
class FunctionDefinition(BaseModel):
|
||||||
@@ -29,14 +29,16 @@ class ToolCall(BaseModel):
|
|||||||
|
|
||||||
class ChatMessage(BaseModel):
|
class ChatMessage(BaseModel):
|
||||||
"""A chat message."""
|
"""A chat message."""
|
||||||
role: Literal["system", "user", "assistant", "tool"] = Field(..., description="Role of the message sender")
|
role: Literal["system", "user", "assistant", "tool"] = Field(..., description="Role of message sender")
|
||||||
content: Optional[str] = Field(default=None, description="Message content")
|
content: str = Field(default="", description="Message content")
|
||||||
tool_calls: Optional[List[ToolCall]] = Field(default_factory=list, description="Tool calls from assistant")
|
tool_calls: Optional[List[ToolCall]] = Field(default=None, description="Tool calls from assistant")
|
||||||
#tool_call_id: Optional[str] = Field(default=None, description="ID of tool call this message is responding to")
|
#tool_call_id: Optional[str] = Field(default=None, description="ID of tool call this message is responding to")
|
||||||
#name: Optional[str] = Field(default=None, description="Name of the tool/function")
|
#name: Optional[str] = Field(default=None, description="Name of the tool/function")
|
||||||
|
|
||||||
class Config:
|
model_config = ConfigDict(
|
||||||
exclude_none = True
|
# Use Pydantic's exclude_none to omit tool_calls when None
|
||||||
|
exclude_none=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionRequest(BaseModel):
|
class ChatCompletionRequest(BaseModel):
|
||||||
@@ -50,9 +52,9 @@ class ChatCompletionRequest(BaseModel):
|
|||||||
stop: Optional[List[str]] = Field(default=None, description="Stop sequences")
|
stop: Optional[List[str]] = Field(default=None, description="Stop sequences")
|
||||||
tools: Optional[List[Tool]] = Field(default=None, description="List of tools the model may call")
|
tools: Optional[List[Tool]] = Field(default=None, description="List of tools the model may call")
|
||||||
tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(default="auto", description="How to choose tools")
|
tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(default="auto", description="How to choose tools")
|
||||||
|
|
||||||
class Config:
|
model_config = ConfigDict(
|
||||||
json_schema_extra = {
|
json_schema_extra={
|
||||||
"example": {
|
"example": {
|
||||||
"model": "local-swarm",
|
"model": "local-swarm",
|
||||||
"messages": [
|
"messages": [
|
||||||
@@ -62,12 +64,14 @@ class ChatCompletionRequest(BaseModel):
|
|||||||
"temperature": 0.7
|
"temperature": 0.7
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionChoice(BaseModel):
|
class ChatCompletionChoice(BaseModel):
|
||||||
"""A choice in the chat completion response."""
|
"""A choice in the chat completion response."""
|
||||||
index: int = Field(..., description="Choice index")
|
index: int = Field(..., description="Choice index")
|
||||||
message: ChatMessage = Field(..., description="Generated message")
|
message: ChatMessage = Field(..., description="Generated message")
|
||||||
|
logprobs: Optional[Any] = Field(default=None, description="Log probabilities")
|
||||||
finish_reason: Optional[str] = Field(default="stop", description="Reason for finishing (stop, length, tool_calls, etc.)")
|
finish_reason: Optional[str] = Field(default="stop", description="Reason for finishing (stop, length, tool_calls, etc.)")
|
||||||
|
|
||||||
|
|
||||||
@@ -87,6 +91,8 @@ class ChatCompletionResponse(BaseModel):
|
|||||||
model: str = Field(..., description="Model used")
|
model: str = Field(..., description="Model used")
|
||||||
choices: List[ChatCompletionChoice] = Field(..., description="Generated choices")
|
choices: List[ChatCompletionChoice] = Field(..., description="Generated choices")
|
||||||
usage: UsageInfo = Field(..., description="Token usage")
|
usage: UsageInfo = Field(..., description="Token usage")
|
||||||
|
stats: Dict[str, Any] = Field(default_factory=dict, description="Additional stats")
|
||||||
|
system_fingerprint: Optional[str] = Field(default=None, description="System fingerprint")
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionStreamChoice(BaseModel):
|
class ChatCompletionStreamChoice(BaseModel):
|
||||||
|
|||||||
+76
-1
@@ -224,6 +224,43 @@ def set_federated_swarm(swarm):
|
|||||||
federated_swarm = swarm
|
federated_swarm = swarm
|
||||||
|
|
||||||
|
|
||||||
|
async def _stream_response(response: ChatCompletionResponse):
|
||||||
|
"""Stream a chat completion response as Server-Sent Events.
|
||||||
|
|
||||||
|
For compatibility with OpenAI format, we use delta format for streaming.
|
||||||
|
The response is sent as a single chunk since we don't support
|
||||||
|
true token-by-token streaming yet.
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
from api.models import ChatCompletionStreamResponse, ChatCompletionStreamChoice
|
||||||
|
|
||||||
|
# Convert to streaming format with delta
|
||||||
|
message = response.choices[0].message
|
||||||
|
choice = ChatCompletionStreamChoice(
|
||||||
|
index=0,
|
||||||
|
delta={"content": message.content},
|
||||||
|
finish_reason="stop"
|
||||||
|
)
|
||||||
|
|
||||||
|
stream_response = ChatCompletionStreamResponse(
|
||||||
|
id=response.id,
|
||||||
|
created=response.created,
|
||||||
|
model=response.model,
|
||||||
|
choices=[choice]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send as SSE event
|
||||||
|
data = stream_response.model_dump_json(exclude_none=True)
|
||||||
|
logger.debug(f"Streaming SSE data (delta format): {len(data)} chars")
|
||||||
|
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
|
||||||
|
# Send done event
|
||||||
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
logger.debug(f"Streaming complete")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/v1/chat/completions")
|
@router.post("/v1/chat/completions")
|
||||||
async def chat_completions(request: ChatCompletionRequest, fastapi_request: Request):
|
async def chat_completions(request: ChatCompletionRequest, fastapi_request: Request):
|
||||||
"""Generate chat completion."""
|
"""Generate chat completion."""
|
||||||
@@ -239,6 +276,10 @@ async def chat_completions(request: ChatCompletionRequest, fastapi_request: Requ
|
|||||||
client_working_dir = fastapi_request.headers.get("X-Client-Working-Dir")
|
client_working_dir = fastapi_request.headers.get("X-Client-Working-Dir")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
logger.info(f"📥 Processing chat completion request...")
|
||||||
|
logger.info(f" Stream: {request.stream}")
|
||||||
|
logger.info(f" Model: {request.model}")
|
||||||
|
|
||||||
response = await handle_chat_completion(
|
response = await handle_chat_completion(
|
||||||
request=request,
|
request=request,
|
||||||
swarm_manager=swarm_manager,
|
swarm_manager=swarm_manager,
|
||||||
@@ -246,7 +287,41 @@ async def chat_completions(request: ChatCompletionRequest, fastapi_request: Requ
|
|||||||
client_working_dir=client_working_dir,
|
client_working_dir=client_working_dir,
|
||||||
use_opencode_tools=_USE_OPENCODE_TOOLS
|
use_opencode_tools=_USE_OPENCODE_TOOLS
|
||||||
)
|
)
|
||||||
return response
|
|
||||||
|
logger.info(f"✅ Response generated successfully")
|
||||||
|
logger.debug(f"Response object type: {type(response)}")
|
||||||
|
|
||||||
|
# Handle streaming if requested
|
||||||
|
if request.stream:
|
||||||
|
logger.info(f"🌊 Returning streaming response")
|
||||||
|
return StreamingResponse(
|
||||||
|
_stream_response(response),
|
||||||
|
media_type="text/event-stream"
|
||||||
|
)
|
||||||
|
|
||||||
|
# For non-streaming, return JSON with proper handling of None fields:
|
||||||
|
# - tool_calls: omit when None (no tools)
|
||||||
|
# - logprobs: always include as null (even when None)
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
logger.info(f"📋 Returning JSON response")
|
||||||
|
|
||||||
|
# Build response dict with custom handling
|
||||||
|
response_dict = response.model_dump(exclude_none=True)
|
||||||
|
|
||||||
|
# Ensure logprobs is always present (as null if not available)
|
||||||
|
for choice in response_dict.get('choices', []):
|
||||||
|
if 'logprobs' not in choice:
|
||||||
|
choice['logprobs'] = None
|
||||||
|
|
||||||
|
logger.debug(f"Response dict: {response_dict}")
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
content=response_dict,
|
||||||
|
status_code=200
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Error in chat completion")
|
logger.exception("Error in chat completion")
|
||||||
|
logger.error(f"Error type: {type(e).__name__}")
|
||||||
|
logger.error(f"Error message: {str(e)}")
|
||||||
raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
|
||||||
|
|||||||
@@ -76,7 +76,8 @@ class MainRunner:
|
|||||||
config = select_optimal_model(
|
config = select_optimal_model(
|
||||||
self.hardware,
|
self.hardware,
|
||||||
preferred_model=self.args.model,
|
preferred_model=self.args.model,
|
||||||
force_instances=self.args.instances
|
force_instances=self.args.instances,
|
||||||
|
use_mlx=None # Auto-detect based on hardware
|
||||||
)
|
)
|
||||||
|
|
||||||
if not config:
|
if not config:
|
||||||
|
|||||||
+115
-1
@@ -123,6 +123,10 @@ class ModelRegistry:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
meta = self._metadata[model_id]
|
meta = self._metadata[model_id]
|
||||||
|
# Ensure meta is a dict (not a string like "_comment")
|
||||||
|
if not isinstance(meta, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
sizes = self._mlx_sizes if use_mlx else self._gguf_sizes
|
sizes = self._mlx_sizes if use_mlx else self._gguf_sizes
|
||||||
quality_map = self._get_quality_map(use_mlx)
|
quality_map = self._get_quality_map(use_mlx)
|
||||||
|
|
||||||
@@ -154,7 +158,10 @@ class ModelRegistry:
|
|||||||
def list_models(self, use_mlx: bool = False) -> List[Model]:
|
def list_models(self, use_mlx: bool = False) -> List[Model]:
|
||||||
"""List all available models."""
|
"""List all available models."""
|
||||||
models = []
|
models = []
|
||||||
for model_id in self._metadata.keys():
|
for model_id, meta in self._metadata.items():
|
||||||
|
# Skip non-dict entries (like _comment)
|
||||||
|
if not isinstance(meta, dict):
|
||||||
|
continue
|
||||||
model = self.get_model(model_id, use_mlx)
|
model = self.get_model(model_id, use_mlx)
|
||||||
if model:
|
if model:
|
||||||
models.append(model)
|
models.append(model)
|
||||||
@@ -192,3 +199,110 @@ def get_model(model_id: str, use_mlx: bool = False) -> Optional[Model]:
|
|||||||
def list_models(use_mlx: bool = False) -> List[Model]:
|
def list_models(use_mlx: bool = False) -> List[Model]:
|
||||||
"""List all available models."""
|
"""List all available models."""
|
||||||
return _registry.list_models(use_mlx)
|
return _registry.list_models(use_mlx)
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_hf_repo(model_id: str, variant: ModelVariant, quant: QuantizationConfig) -> Optional[str]:
|
||||||
|
"""Get HuggingFace repository ID for a GGUF model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_id: Model identifier
|
||||||
|
variant: Model variant (size)
|
||||||
|
quant: Quantization config
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
HuggingFace repo ID (e.g., "Qwen/Qwen2.5-Coder-7B-Instruct-GGUF") or None if unknown
|
||||||
|
"""
|
||||||
|
# Get the base repo from metadata
|
||||||
|
if model_id not in _registry._metadata:
|
||||||
|
return None
|
||||||
|
|
||||||
|
meta = _registry._metadata[model_id]
|
||||||
|
if not isinstance(meta, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
base_repo = meta.get("hf_repo")
|
||||||
|
if not base_repo:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Convert variant size (e.g., "14b" to "-14B") and construct repo ID
|
||||||
|
size_suffix = f"-{variant.size.upper()}"
|
||||||
|
|
||||||
|
# For GGUF, add -Instruct-GGUF suffix
|
||||||
|
repo_id = f"{base_repo}{size_suffix}-Instruct-GGUF"
|
||||||
|
|
||||||
|
return repo_id
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_hf_repo_mlx(model_id: str, variant: ModelVariant, quant: QuantizationConfig) -> Optional[str]:
|
||||||
|
"""Get HuggingFace repository ID for an MLX model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_id: Model identifier
|
||||||
|
variant: Model variant (size)
|
||||||
|
quant: Quantization config
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
HuggingFace repo ID (e.g., "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit") or None if unknown
|
||||||
|
"""
|
||||||
|
# Get the base repo from metadata
|
||||||
|
if model_id not in _registry._metadata:
|
||||||
|
return None
|
||||||
|
|
||||||
|
meta = _registry._metadata[model_id]
|
||||||
|
if not isinstance(meta, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
base_repo = meta.get("hf_repo")
|
||||||
|
if not base_repo:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# MLX models are typically in mlx-community namespace
|
||||||
|
# Format: mlx-community/{ModelName}-{Size}-{Quantization}
|
||||||
|
# For example: mlx-community/Qwen2.5-Coder-14B-Instruct-4bit
|
||||||
|
|
||||||
|
# Convert variant size (e.g., "14b" to "-14B")
|
||||||
|
size_suffix = f"-{variant.size.upper()}"
|
||||||
|
|
||||||
|
# Add quantization suffix (e.g., "-4bit" for MLX quantization names)
|
||||||
|
quant_suffix = f"-{quant.name}"
|
||||||
|
|
||||||
|
# Construct the full repo name
|
||||||
|
model_name = base_repo.split('/')[-1] # Get just the model name, not the org
|
||||||
|
repo_id = f"mlx-community/{model_name}{size_suffix}-Instruct{quant_suffix}"
|
||||||
|
|
||||||
|
return repo_id
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_filename(model_id: str, variant: ModelVariant, quant: QuantizationConfig) -> str:
|
||||||
|
"""Get the filename for a GGUF model file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_id: Model identifier
|
||||||
|
variant: Model variant (size)
|
||||||
|
quant: Quantization config
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
GGUF filename (e.g., "qwen2.5-coder-14b-instruct-q4_k_m.gguf")
|
||||||
|
"""
|
||||||
|
# Extract model name from metadata
|
||||||
|
if model_id not in _registry._metadata:
|
||||||
|
meta = {"name": model_id, "hf_repo": model_id}
|
||||||
|
else:
|
||||||
|
meta = _registry._metadata[model_id]
|
||||||
|
|
||||||
|
if not isinstance(meta, dict):
|
||||||
|
meta = {"name": model_id, "hf_repo": model_id}
|
||||||
|
|
||||||
|
# Use the base repo name or model name
|
||||||
|
base_name = meta.get("hf_repo", meta.get("name", model_id))
|
||||||
|
# Remove org prefix if present
|
||||||
|
if '/' in base_name:
|
||||||
|
base_name = base_name.split('/')[-1]
|
||||||
|
|
||||||
|
# Standard GGUF naming (all lowercase): {model}-{variant}-instruct-{quantization}.gguf
|
||||||
|
# For example: qwen2.5-coder-14b-instruct-q4_k_m.gguf
|
||||||
|
variant_size = f"-{variant.size.lower()}"
|
||||||
|
quant_name = quant.name.lower()
|
||||||
|
filename = f"{base_name.lower()}{variant_size}-instruct-{quant_name}.gguf"
|
||||||
|
|
||||||
|
return filename
|
||||||
|
|||||||
@@ -76,11 +76,12 @@ def select_optimal_model(
|
|||||||
force_instances: Optional[int] = None,
|
force_instances: Optional[int] = None,
|
||||||
context_size: int = 32768,
|
context_size: int = 32768,
|
||||||
offload_percent: float = 0.0,
|
offload_percent: float = 0.0,
|
||||||
use_mlx: bool = False
|
use_mlx: Optional[bool] = None
|
||||||
) -> Optional[ModelConfig]:
|
) -> Optional[ModelConfig]:
|
||||||
"""Select the optimal model configuration for given hardware."""
|
"""Select the optimal model configuration for given hardware."""
|
||||||
if use_mlx is None and hardware.is_apple_silicon:
|
# Auto-detect MLX usage for Apple Silicon if not explicitly set
|
||||||
use_mlx = True
|
if use_mlx is None:
|
||||||
|
use_mlx = hardware.is_apple_silicon
|
||||||
|
|
||||||
available_vram, _ = get_available_memory_with_offload(hardware, offload_percent)
|
available_vram, _ = get_available_memory_with_offload(hardware, offload_percent)
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,140 @@
|
|||||||
|
"""Test Apple Silicon MLX auto-detection and download."""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add src to path
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
||||||
|
|
||||||
|
def test_apple_silicon_mlx_selection():
|
||||||
|
"""Test that Apple Silicon correctly selects MLX models."""
|
||||||
|
from hardware.detector import HardwareProfile, GPUInfo
|
||||||
|
from models.selector import select_optimal_model
|
||||||
|
|
||||||
|
# Mock Apple Silicon hardware
|
||||||
|
class MockAppleHardware:
|
||||||
|
os = "darwin"
|
||||||
|
cpu_cores = 12
|
||||||
|
ram_gb = 24.0
|
||||||
|
ram_available_gb = 12.0
|
||||||
|
is_apple_silicon = True
|
||||||
|
has_dedicated_gpu = False
|
||||||
|
gpu = GPUInfo(name="Apple Silicon GPU", vram_gb=24.0, driver_version=None)
|
||||||
|
available_memory_gb = 12.0
|
||||||
|
recommended_memory_gb = 12.0
|
||||||
|
|
||||||
|
hardware = MockAppleHardware()
|
||||||
|
|
||||||
|
# Test auto-detection (use_mlx=None)
|
||||||
|
print("=" * 60)
|
||||||
|
print("Apple Silicon MLX Auto-Detection Test")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
print("\n1. Testing auto-detection (use_mlx=None)...")
|
||||||
|
config = select_optimal_model(hardware, use_mlx=None)
|
||||||
|
|
||||||
|
assert config is not None, "Should find a model"
|
||||||
|
print(f" ✓ Model selected: {config.model.name}")
|
||||||
|
|
||||||
|
# Verify quantization is MLX format (4bit, 8bit, etc.)
|
||||||
|
print("\n2. Verifying MLX quantization format...")
|
||||||
|
is_mlx_format = 'bit' in config.quantization.name.lower()
|
||||||
|
assert is_mlx_format, f"Quantization should be MLX format (4bit/8bit), got {config.quantization.name}"
|
||||||
|
print(f" ✓ Quantization: {config.quantization.name} (MLX format)")
|
||||||
|
|
||||||
|
# Test repository name generation
|
||||||
|
print("\n3. Testing MLX repository name generation...")
|
||||||
|
from models.registry import get_model_hf_repo_mlx
|
||||||
|
|
||||||
|
mlx_repo = get_model_hf_repo_mlx(config.model.id, config.variant, config.quantization)
|
||||||
|
assert mlx_repo is not None, "MLX repository should be generated"
|
||||||
|
assert "mlx-community" in mlx_repo, "Should use mlx-community namespace"
|
||||||
|
assert "-Instruct-" in mlx_repo, "Should have -Instruct- suffix"
|
||||||
|
assert config.quantization.name in mlx_repo, "Should include quantization"
|
||||||
|
print(f" ✓ Repository: {mlx_repo}")
|
||||||
|
|
||||||
|
# Verify it's NOT using GGUF format
|
||||||
|
print("\n4. Verifying NOT using GGUF format...")
|
||||||
|
has_gguf = 'q4_k_m' in config.quantization.name or 'q5_k_m' in config.quantization.name
|
||||||
|
has_gguf_suffix = '-GGUF' in mlx_repo
|
||||||
|
assert not has_gguf, f"Should not use GGUF quantization names"
|
||||||
|
assert not has_gguf_suffix, f"Should not use GGUF repository suffix"
|
||||||
|
print(f" ✓ Not using GGUF format")
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("All Apple Silicon MLX tests passed!")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
def test_nvidia_gpu_gguf_selection():
|
||||||
|
"""Test that NVIDIA GPU correctly selects GGUF models."""
|
||||||
|
from hardware.detector import HardwareProfile, GPUInfo
|
||||||
|
from models.selector import select_optimal_model
|
||||||
|
|
||||||
|
# Mock NVIDIA hardware
|
||||||
|
class MockNvidiaHardware:
|
||||||
|
os = "linux"
|
||||||
|
cpu_cores = 8
|
||||||
|
ram_gb = 32.0
|
||||||
|
ram_available_gb = 20.0
|
||||||
|
is_apple_silicon = False
|
||||||
|
has_dedicated_gpu = True
|
||||||
|
gpu = GPUInfo(name="NVIDIA RTX 4090", vram_gb=24.0, driver_version="550.80")
|
||||||
|
available_memory_gb = 20.0
|
||||||
|
recommended_memory_gb = 20.0
|
||||||
|
|
||||||
|
hardware = MockNvidiaHardware()
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("NVIDIA GPU GGUF Auto-Detection Test")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
print("\n1. Testing auto-detection (use_mlx=None)...")
|
||||||
|
config = select_optimal_model(hardware, use_mlx=None)
|
||||||
|
|
||||||
|
assert config is not None, "Should find a model"
|
||||||
|
print(f" ✓ Model selected: {config.model.name}")
|
||||||
|
|
||||||
|
# Verify quantization is GGUF format (q4_k_m, q5_k_m, etc.)
|
||||||
|
print("\n2. Verifying GGUF quantization format...")
|
||||||
|
is_gguf_format = 'q' in config.quantization.name.lower()
|
||||||
|
assert is_gguf_format, f"Quantization should be GGUF format (q4_k_m/q5_k_m), got {config.quantization.name}"
|
||||||
|
print(f" ✓ Quantization: {config.quantization.name} (GGUF format)")
|
||||||
|
|
||||||
|
# Test repository name generation
|
||||||
|
print("\n3. Testing GGUF repository name generation...")
|
||||||
|
from models.registry import get_model_hf_repo
|
||||||
|
|
||||||
|
gguf_repo = get_model_hf_repo(config.model.id, config.variant, config.quantization)
|
||||||
|
assert gguf_repo is not None, "GGUF repository should be generated"
|
||||||
|
assert "-GGUF" in gguf_repo, "Should have -GGUF suffix"
|
||||||
|
print(f" ✓ Repository: {gguf_repo}")
|
||||||
|
|
||||||
|
# Verify it's NOT using MLX format
|
||||||
|
print("\n4. Verifying NOT using MLX format...")
|
||||||
|
has_mlx_format = 'bit' in config.quantization.name.lower() and config.quantization.name not in ['q4_k_m', 'q5_k_m', 'q6_k']
|
||||||
|
has_mlx_namespace = 'mlx-community' in gguf_repo
|
||||||
|
assert not has_mlx_namespace, f"Should not use mlx-community namespace"
|
||||||
|
print(f" ✓ Not using MLX format")
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("All NVIDIA GPU GGUF tests passed!")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
test_apple_silicon_mlx_selection()
|
||||||
|
test_nvidia_gpu_gguf_selection()
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("ALL AUTO-DETECTION TESTS PASSED!")
|
||||||
|
print("=" * 60)
|
||||||
|
except AssertionError as e:
|
||||||
|
print(f"\n❌ Test failed: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ Test error: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
sys.exit(1)
|
||||||
@@ -0,0 +1,100 @@
|
|||||||
|
"""End-to-end test for tool execution with a mock server.
|
||||||
|
|
||||||
|
This tests the complete flow:
|
||||||
|
1. Model generates tool call
|
||||||
|
2. Tools are executed
|
||||||
|
3. Response is generated based on tool results
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_flow():
|
||||||
|
"""Test the tool execution flow end-to-end."""
|
||||||
|
|
||||||
|
# Import after path is set
|
||||||
|
from api.models import ChatMessage, ChatCompletionRequest
|
||||||
|
from api.tool_parser import parse_tool_calls
|
||||||
|
from api.formatting import format_messages_with_tools
|
||||||
|
from tools.executor import ToolExecutor
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("End-to-End Tool Execution Test")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Test 1: Parse tool call from model response
|
||||||
|
print("\n1. Testing tool parsing...")
|
||||||
|
model_response = "TOOL: bash\nARGUMENTS: {\"command\": \"echo hello\"}"
|
||||||
|
|
||||||
|
content, tool_calls = parse_tool_calls(model_response)
|
||||||
|
assert tool_calls is not None, "Should parse tool call"
|
||||||
|
assert len(tool_calls) == 1, "Should have one tool call"
|
||||||
|
assert tool_calls[0]["function"]["name"] == "bash", "Should be bash tool"
|
||||||
|
print(f" ✓ Parsed tool: {tool_calls[0]['function']['name']}")
|
||||||
|
|
||||||
|
# Test 2: Simulate tool result and format for next prompt
|
||||||
|
print("\n2. Testing tool result formatting...")
|
||||||
|
tool_result = "hello\n"
|
||||||
|
|
||||||
|
# Build conversation history
|
||||||
|
messages = [
|
||||||
|
ChatMessage(role="user", content="Run echo hello"),
|
||||||
|
ChatMessage(role="assistant", content=model_response),
|
||||||
|
ChatMessage(role="tool", content=tool_result)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Format for next generation
|
||||||
|
next_prompt = format_messages_with_tools(messages, None)
|
||||||
|
assert "tool" in next_prompt.lower(), "Prompt should include tool result"
|
||||||
|
assert "hello" in next_prompt, "Prompt should include tool output"
|
||||||
|
print(f" ✓ Tool result formatted for next prompt")
|
||||||
|
|
||||||
|
# Test 3: Verify loop detection
|
||||||
|
print("\n3. Testing loop detection...")
|
||||||
|
seen_tools = set()
|
||||||
|
|
||||||
|
# First tool call
|
||||||
|
tc1 = [{"function": {"name": "bash", "arguments": '{"command": "ls"}'}}]
|
||||||
|
sig1 = "bash:{'command': \"ls\"}'[:50]"
|
||||||
|
seen_tools.add(sig1)
|
||||||
|
print(f" ✓ First tool call tracked")
|
||||||
|
|
||||||
|
# Duplicate tool call
|
||||||
|
tc2 = tc1
|
||||||
|
sig2 = sig1
|
||||||
|
is_duplicate = sig2 in seen_tools
|
||||||
|
assert is_duplicate, "Should detect duplicate"
|
||||||
|
print(f" ✓ Duplicate tool call detected")
|
||||||
|
|
||||||
|
# Test 4: Verify tool result truncation
|
||||||
|
print("\n4. Testing tool result truncation...")
|
||||||
|
long_result = "a" * 3000
|
||||||
|
max_length = 2000
|
||||||
|
|
||||||
|
if len(long_result) > max_length:
|
||||||
|
truncated = long_result[:max_length] + "\n[...truncated...]"
|
||||||
|
assert len(truncated) == max_length + len("\n[...truncated...]"), "Should truncate properly"
|
||||||
|
print(f" ✓ Tool result truncated from {len(long_result)} to {len(truncated)} chars")
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("All end-to-end tests passed!")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
asyncio.run(test_tool_flow())
|
||||||
|
except AssertionError as e:
|
||||||
|
print(f"\n❌ Test failed: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ Test error: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
sys.exit(1)
|
||||||
@@ -0,0 +1,183 @@
|
|||||||
|
"""Integration test for tool execution in chat completions.
|
||||||
|
|
||||||
|
This test verifies that:
|
||||||
|
1. Tools are properly parsed from model output
|
||||||
|
2. Tools are executed and results fed back to model
|
||||||
|
3. The loop continues generating until final response
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
||||||
|
|
||||||
|
from api.models import ChatMessage
|
||||||
|
from api.chat_handlers import handle_chat_completion, _sanitize_tools
|
||||||
|
from api.tool_parser import parse_tool_calls
|
||||||
|
from api.formatting import format_messages_with_tools
|
||||||
|
|
||||||
|
|
||||||
|
class MockSwarm:
|
||||||
|
"""Mock swarm manager for testing."""
|
||||||
|
|
||||||
|
async def generate(self, prompt, max_tokens, temperature, use_consensus):
|
||||||
|
"""Generate a mock response."""
|
||||||
|
# Return different responses based on prompt content
|
||||||
|
if "tool_result" in prompt.lower():
|
||||||
|
# Final response after tool execution
|
||||||
|
return MockResponse("Here's the result: The tool was executed successfully!")
|
||||||
|
else:
|
||||||
|
# First response with tool call
|
||||||
|
return MockResponse("TOOL: bash\nARGUMENTS: {\"command\": \"echo test\"}")
|
||||||
|
|
||||||
|
|
||||||
|
class MockResponse:
|
||||||
|
"""Mock generation result."""
|
||||||
|
|
||||||
|
def __init__(self, text):
|
||||||
|
self.selected_response = MockSelectedResponse(text)
|
||||||
|
|
||||||
|
|
||||||
|
class MockSelectedResponse:
|
||||||
|
"""Mock selected response."""
|
||||||
|
|
||||||
|
def __init__(self, text):
|
||||||
|
self.text = text
|
||||||
|
self.tokens_generated = 50
|
||||||
|
self.tokens_per_second = 10.0
|
||||||
|
|
||||||
|
|
||||||
|
class MockExecutor:
|
||||||
|
"""Mock tool executor."""
|
||||||
|
|
||||||
|
async def execute_tool(self, tool_name, tool_args, working_dir=None):
|
||||||
|
"""Execute a tool mock."""
|
||||||
|
return f"Mock result from {tool_name} with args {tool_args}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_execution_loop():
|
||||||
|
"""Test that tools are executed and loop continues."""
|
||||||
|
print("Testing tool execution loop...")
|
||||||
|
|
||||||
|
# Create a mock request
|
||||||
|
request = ChatMessage(
|
||||||
|
role="user",
|
||||||
|
content="Run echo test"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wrap in request object
|
||||||
|
from api.models import ChatCompletionRequest
|
||||||
|
req = ChatCompletionRequest(
|
||||||
|
model="test-model",
|
||||||
|
messages=[request],
|
||||||
|
tools=None,
|
||||||
|
max_tokens=1024,
|
||||||
|
temperature=0.7
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create mock swarm
|
||||||
|
swarm = MockSwarm()
|
||||||
|
|
||||||
|
# We can't easily test the full handler without a real tool executor,
|
||||||
|
# so let's test the key parts
|
||||||
|
|
||||||
|
# Test 1: Verify tool parsing works
|
||||||
|
print(" Test 1: Tool parsing")
|
||||||
|
tool_text = 'TOOL: bash\nARGUMENTS: {"command": "echo test"}'
|
||||||
|
content, tool_calls = parse_tool_calls(tool_text)
|
||||||
|
|
||||||
|
assert tool_calls is not None, "Tool calls should be parsed"
|
||||||
|
assert len(tool_calls) == 1, "Should parse one tool call"
|
||||||
|
assert tool_calls[0]["function"]["name"] == "bash", "Tool name should be bash"
|
||||||
|
assert "echo test" in tool_calls[0]["function"]["arguments"], "Command should be in arguments"
|
||||||
|
print(" ✓ Tool parsing works correctly")
|
||||||
|
|
||||||
|
# Test 2: Verify tool instructions are loaded
|
||||||
|
print(" Test 2: Tool instructions")
|
||||||
|
instructions = format_messages_with_tools([request], None)
|
||||||
|
assert len(instructions) > 0, "Instructions should be generated"
|
||||||
|
assert "tool" in instructions.lower(), "Instructions should mention tools"
|
||||||
|
print(" ✓ Tool instructions are loaded")
|
||||||
|
|
||||||
|
# Test 3: Verify multiple tool calls can be parsed
|
||||||
|
print(" Test 3: Multiple tool calls")
|
||||||
|
multi_tool = '''TOOL: bash
|
||||||
|
ARGUMENTS: {"command": "ls"}
|
||||||
|
|
||||||
|
TOOL: write
|
||||||
|
ARGUMENTS: {"filePath": "test.txt", "content": "hello"}'''
|
||||||
|
content, tool_calls = parse_tool_calls(multi_tool)
|
||||||
|
assert tool_calls is not None, "Multiple tools should be parsed"
|
||||||
|
assert len(tool_calls) == 2, "Should parse two tool calls"
|
||||||
|
assert tool_calls[0]["function"]["name"] == "bash", "First tool should be bash"
|
||||||
|
assert tool_calls[1]["function"]["name"] == "write", "Second tool should be write"
|
||||||
|
print(" ✓ Multiple tool calls parsed correctly")
|
||||||
|
|
||||||
|
# Test 4: Verify tool sanitization
|
||||||
|
print(" Test 4: Tool sanitization")
|
||||||
|
# Create mock tool with invalid 'description' in properties
|
||||||
|
from api.models import Tool, FunctionDefinition
|
||||||
|
mock_tool = Tool(
|
||||||
|
type="function",
|
||||||
|
function=FunctionDefinition(
|
||||||
|
name="test_tool",
|
||||||
|
description="Test tool",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"description": "Invalid field",
|
||||||
|
"param1": {"type": "string"}
|
||||||
|
},
|
||||||
|
"required": ["description", "param1"]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
sanitized = _sanitize_tools([mock_tool])
|
||||||
|
assert len(sanitized) == 1, "Should return one tool"
|
||||||
|
assert "description" not in sanitized[0].function.parameters.get("properties", {}), \
|
||||||
|
"Should remove invalid 'description' from properties"
|
||||||
|
print(" ✓ Tool sanitization removes invalid fields")
|
||||||
|
|
||||||
|
print("\n✅ All tool execution loop tests passed!")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_tool_parsing():
|
||||||
|
"""Test that normal responses without tools work."""
|
||||||
|
print("\nTesting response without tools...")
|
||||||
|
|
||||||
|
# Test normal response
|
||||||
|
normal_text = "This is a normal response without any tool calls."
|
||||||
|
content, tool_calls = parse_tool_calls(normal_text)
|
||||||
|
|
||||||
|
assert tool_calls is None, "No tool calls should be found"
|
||||||
|
assert content == normal_text, "Content should be returned unchanged"
|
||||||
|
print(" ✓ Normal responses pass through without modification")
|
||||||
|
|
||||||
|
print("\n✅ No-tool parsing test passed!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
async def run_tests():
|
||||||
|
try:
|
||||||
|
await test_tool_execution_loop()
|
||||||
|
await test_no_tool_parsing()
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("All integration tests passed!")
|
||||||
|
print("=" * 60)
|
||||||
|
except AssertionError as e:
|
||||||
|
print(f"\n❌ Test failed: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
sys.exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ Test error: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
asyncio.run(run_tests())
|
||||||
@@ -0,0 +1,105 @@
|
|||||||
|
"""Test to verify tool execution is triggered when model generates tool calls."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_execution_triggered():
|
||||||
|
"""Verify that tool execution is properly triggered."""
|
||||||
|
|
||||||
|
from api.models import ChatMessage, ChatCompletionRequest
|
||||||
|
from api.chat_handlers import handle_chat_completion
|
||||||
|
from api.tool_parser import parse_tool_calls
|
||||||
|
from tools.executor import ToolExecutor, set_tool_executor
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("Tool Execution Trigger Test")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Create a mock swarm that generates a tool call
|
||||||
|
class MockSwarm:
|
||||||
|
async def generate(self, prompt, max_tokens, temperature, use_consensus):
|
||||||
|
# First call: generate tool call
|
||||||
|
if "user" in prompt and "echo hello" in prompt:
|
||||||
|
return MockResult("TOOL: bash\nARGUMENTS: {\"command\": \"echo hello\"}")
|
||||||
|
# Second call: after tool result, generate answer
|
||||||
|
elif "tool" in prompt.lower():
|
||||||
|
return MockResult("Output: hello\nThe command executed successfully!")
|
||||||
|
else:
|
||||||
|
return MockResult("I don't understand")
|
||||||
|
|
||||||
|
class MockResult:
|
||||||
|
def __init__(self, text):
|
||||||
|
self.selected_response = MockSelectedResponse(text)
|
||||||
|
|
||||||
|
class MockSelectedResponse:
|
||||||
|
def __init__(self, text):
|
||||||
|
self.text = text
|
||||||
|
self.tokens_generated = 20
|
||||||
|
self.tokens_per_second = 5.0
|
||||||
|
|
||||||
|
# Set up tool executor
|
||||||
|
executor = ToolExecutor(tool_host_url=None)
|
||||||
|
set_tool_executor(executor)
|
||||||
|
|
||||||
|
# Create request
|
||||||
|
request = ChatCompletionRequest(
|
||||||
|
model="test-model",
|
||||||
|
messages=[ChatMessage(role="user", content="echo hello")],
|
||||||
|
tools=None, # No explicit tools - should still parse from response
|
||||||
|
max_tokens=1024,
|
||||||
|
temperature=0.7
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n1. Testing that tool calls are parsed...")
|
||||||
|
model_response = "TOOL: bash\nARGUMENTS: {\"command\": \"echo hello\"}"
|
||||||
|
content, tool_calls = parse_tool_calls(model_response)
|
||||||
|
|
||||||
|
assert tool_calls is not None, "Tool calls should be parsed from response"
|
||||||
|
assert len(tool_calls) == 1, "Should have one tool call"
|
||||||
|
print(f" ✓ Tool call parsed: {tool_calls[0]['function']['name']}")
|
||||||
|
|
||||||
|
print("\n2. Verifying tool executor is set...")
|
||||||
|
from tools.executor import get_tool_executor
|
||||||
|
current_executor = get_tool_executor()
|
||||||
|
assert current_executor is not None, "Tool executor should be set"
|
||||||
|
print(f" ✓ Tool executor configured: {current_executor.tool_host_url or 'local'}")
|
||||||
|
|
||||||
|
print("\n3. Testing tool execution...")
|
||||||
|
# Try to execute the tool
|
||||||
|
try:
|
||||||
|
from api.routes import execute_tool_server_side
|
||||||
|
result = await execute_tool_server_side(
|
||||||
|
"bash",
|
||||||
|
{"command": "echo hello"},
|
||||||
|
working_dir=None
|
||||||
|
)
|
||||||
|
print(f" ✓ Tool executed successfully")
|
||||||
|
print(f" ✓ Result: {result[:50]}..." if len(result) > 50 else f" ✓ Result: {result}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ✗ Tool execution failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("All tool execution trigger tests passed!")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
asyncio.run(test_tool_execution_triggered())
|
||||||
|
except AssertionError as e:
|
||||||
|
print(f"\n❌ Test failed: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
sys.exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ Test error: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
sys.exit(1)
|
||||||
Reference in New Issue
Block a user