refactor: unified generation interface for federation and local modes

- Created _generate_with_consensus() that handles both federation and local generation
- Callers don't need to know which mode is being used - it's transparent
- Tool execution loop uses same unified interface for all iterations
- Removed special-case federation logic from main handler
- Federation is now a transparent layer around generation
- All 41 tests passing
This commit is contained in:
2026-02-25 23:36:24 +01:00
parent 414cb444f3
commit 93844a81b0
+44 -29
View File
@@ -228,25 +228,48 @@ def _create_response(
     return response
 
-async def _generate_with_local_swarm(
-    swarm_manager,
+async def _generate_with_consensus(
     prompt: str,
     max_tokens: int,
     temperature: float,
-    stream: bool = False
+    swarm_manager,
+    federated_swarm=None
 ) -> tuple[str, int, float]:
-    """Generate response using local swarm.
+    """Generate response with consensus (local or federated).
+
+    This is the unified generation interface - it handles both local-only
+    and federated generation transparently. Callers don't need to know
+    which mode is being used.
 
     Args:
-        swarm_manager: Swarm manager instance
         prompt: Prompt to generate from
        max_tokens: Maximum tokens to generate
         temperature: Sampling temperature
-        stream: Whether this is a streaming request
+        swarm_manager: Local swarm manager instance
+        federated_swarm: Optional federated swarm for multi-node consensus
 
     Returns:
         Tuple of (response_text, tokens_generated, tokens_per_second)
     """
+    # Check if federation is available
+    if federated_swarm is not None:
+        peers = federated_swarm.discovery.get_peers()
+        if peers:
+            logger.debug(f"🌐 Using federation with {len(peers)} peer(s)")
+            try:
+                content, tool_calls, finish_reason = await federated_swarm.generate_with_federation(
+                    prompt=prompt,
+                    max_tokens=max_tokens,
+                    temperature=temperature
+                )
+                # Federation returns content directly
+                # Note: tool_calls from federation should be ignored - head node handles tools
+                return content, 0, 0.0  # Tokens/TPS not tracked in federation mode
+            except Exception as e:
+                logger.warning(f"Federation failed, falling back to local: {e}")
+                # Fall through to local generation
+
+    # Local generation (fallback or no federation)
     try:
         result = await swarm_manager.generate(
             prompt=prompt,
@@ -254,13 +277,8 @@ async def _generate_with_local_swarm(
             temperature=temperature,
             use_consensus=True
         )
         response = result.selected_response
-
-        return (
-            response.text,
-            response.tokens_generated,
-            response.tokens_per_second
-        )
+        return response.text, response.tokens_generated, response.tokens_per_second
     except Exception as e:
         logger.exception("Error in swarm generation")
         raise
@@ -544,22 +562,15 @@ async def handle_chat_completion(
         iteration += 1
         logger.info(f"--- Tool Execution Iteration {iteration} ---")
 
-        # Generate response (use federation if available)
+        # Generate response (unified interface - handles federation automatically)
         logger.debug(f"Generating response...")
-        if use_federation and iteration == 1:
-            # First iteration: use federation for consensus
-            logger.info(f"🌐 Using federation for generation...")
-            content, tool_calls, finish_reason = await _generate_with_federation(
-                federated_swarm, prompt, request.max_tokens or 1024, request.temperature or 0.7
-            )
-            response_text = content
-            tokens_generated = 0  # Will be calculated from usage if needed
-            tps = 0.0
-        else:
-            # Subsequent iterations or no federation: use local swarm
-            response_text, tokens_generated, tps = await _generate_with_local_swarm(
-                swarm_manager, prompt, request.max_tokens or 1024, request.temperature or 0.7
-            )
+        response_text, tokens_generated, tps = await _generate_with_consensus(
+            prompt=prompt,
+            max_tokens=request.max_tokens or 1024,
+            temperature=request.temperature or 0.7,
+            swarm_manager=swarm_manager,
+            federated_swarm=federated_swarm
+        )
 
         logger.info(f"Generated response ({len(response_text)} chars, {tokens_generated} tokens)")
         logger.debug(f"Response: {response_text[:200]}...")
@@ -685,8 +696,12 @@ async def handle_chat_completion(
                 logger.info(f"   (tool_call_id: {msg.tool_call_id}, name: {msg.name})")
             logger.debug(f"Full prompt:\n{next_prompt[:1000]}...")
 
-            response_text, tokens_generated, tps = await _generate_with_local_swarm(
-                swarm_manager, next_prompt, request.max_tokens or 1024, request.temperature or 0.7
-            )
+            response_text, tokens_generated, tps = await _generate_with_consensus(
+                prompt=next_prompt,
+                max_tokens=request.max_tokens or 1024,
+                temperature=request.temperature or 0.7,
+                swarm_manager=swarm_manager,
+                federated_swarm=federated_swarm
+            )
 
             logger.info(f"✅ Generated with tool results ({len(response_text)} chars, {tokens_generated} tokens)")