diff --git a/src/api/chat_handlers.py b/src/api/chat_handlers.py index a0bd51a..7330990 100644 --- a/src/api/chat_handlers.py +++ b/src/api/chat_handlers.py @@ -228,25 +228,48 @@ def _create_response( return response -async def _generate_with_local_swarm( - swarm_manager, +async def _generate_with_consensus( prompt: str, max_tokens: int, temperature: float, - stream: bool = False + swarm_manager, + federated_swarm=None ) -> tuple[str, int, float]: - """Generate response using local swarm. + """Generate response with consensus (local or federated). + + This is the unified generation interface - it handles both local-only + and federated generation transparently. Callers don't need to know + which mode is being used. Args: - swarm_manager: Swarm manager instance prompt: Prompt to generate from max_tokens: Maximum tokens to generate temperature: Sampling temperature - stream: Whether this is a streaming request + swarm_manager: Local swarm manager instance + federated_swarm: Optional federated swarm for multi-node consensus Returns: Tuple of (response_text, tokens_generated, tokens_per_second) """ + # Check if federation is available + if federated_swarm is not None: + peers = federated_swarm.discovery.get_peers() + if peers: + logger.debug(f"🌐 Using federation with {len(peers)} peer(s)") + try: + content, tool_calls, finish_reason = await federated_swarm.generate_with_federation( + prompt=prompt, + max_tokens=max_tokens, + temperature=temperature + ) + # Federation returns content directly + # Note: tool_calls from federation should be ignored - head node handles tools + return content, 0, 0.0 # Tokens/TPS not tracked in federation mode + except Exception as e: + logger.warning(f"Federation failed, falling back to local: {e}") + # Fall through to local generation + + # Local generation (fallback or no federation) try: result = await swarm_manager.generate( prompt=prompt, @@ -254,13 +277,8 @@ async def _generate_with_local_swarm( temperature=temperature, 
use_consensus=True ) - response = result.selected_response - return ( - response.text, - response.tokens_generated, - response.tokens_per_second - ) + return result.selected_response.text, result.selected_response.tokens_generated, result.selected_response.tokens_per_second except Exception as e: logger.exception("Error in swarm generation") raise @@ -544,22 +562,15 @@ async def handle_chat_completion( iteration += 1 logger.info(f"--- Tool Execution Iteration {iteration} ---") - # Generate response (use federation if available) + # Generate response (unified interface - handles federation automatically) logger.debug(f"Generating response...") - if use_federation and iteration == 1: - # First iteration: use federation for consensus - logger.info(f"🌐 Using federation for generation...") - content, tool_calls, finish_reason = await _generate_with_federation( - federated_swarm, prompt, request.max_tokens or 1024, request.temperature or 0.7 - ) - response_text = content - tokens_generated = 0 # Will be calculated from usage if needed - tps = 0.0 - else: - # Subsequent iterations or no federation: use local swarm - response_text, tokens_generated, tps = await _generate_with_local_swarm( - swarm_manager, prompt, request.max_tokens or 1024, request.temperature or 0.7 - ) + response_text, tokens_generated, tps = await _generate_with_consensus( + prompt=prompt, + max_tokens=request.max_tokens or 1024, + temperature=request.temperature or 0.7, + swarm_manager=swarm_manager, + federated_swarm=federated_swarm + ) logger.info(f"Generated response ({len(response_text)} chars, {tokens_generated} tokens)") logger.debug(f"Response: {response_text[:200]}...") @@ -685,8 +696,12 @@ async def handle_chat_completion( logger.info(f" (tool_call_id: {msg.tool_call_id}, name: {msg.name})") logger.debug(f"Full prompt:\n{next_prompt[:1000]}...") - response_text, tokens_generated, tps = await _generate_with_local_swarm( - swarm_manager, next_prompt, request.max_tokens or 1024, request.temperature or 0.7 + response_text, tokens_generated, tps 
= await _generate_with_consensus( + prompt=next_prompt, + max_tokens=request.max_tokens or 1024, + temperature=request.temperature or 0.7, + swarm_manager=swarm_manager, + federated_swarm=federated_swarm ) logger.info(f"✅ Generated with tool results ({len(response_text)} chars, {tokens_generated} tokens)")