refactor: unified generation interface for federation and local modes
- Created _generate_with_consensus(), which handles both federated and local generation
- Callers don't need to know which mode is being used; the choice is transparent
- Tool execution loop uses the same unified interface for all iterations
- Removed special-case federation logic from the main handler
- Federation is now a transparent layer around generation
- All 41 tests passing
This commit is contained in:
+44
-29
@@ -228,25 +228,48 @@ def _create_response(
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
async def _generate_with_local_swarm(
|
async def _generate_with_consensus(
|
||||||
swarm_manager,
|
|
||||||
prompt: str,
|
prompt: str,
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
temperature: float,
|
temperature: float,
|
||||||
stream: bool = False
|
swarm_manager,
|
||||||
|
federated_swarm=None
|
||||||
) -> tuple[str, int, float]:
|
) -> tuple[str, int, float]:
|
||||||
"""Generate response using local swarm.
|
"""Generate response with consensus (local or federated).
|
||||||
|
|
||||||
|
This is the unified generation interface - it handles both local-only
|
||||||
|
and federated generation transparently. Callers don't need to know
|
||||||
|
which mode is being used.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
swarm_manager: Swarm manager instance
|
|
||||||
prompt: Prompt to generate from
|
prompt: Prompt to generate from
|
||||||
max_tokens: Maximum tokens to generate
|
max_tokens: Maximum tokens to generate
|
||||||
temperature: Sampling temperature
|
temperature: Sampling temperature
|
||||||
stream: Whether this is a streaming request
|
swarm_manager: Local swarm manager instance
|
||||||
|
federated_swarm: Optional federated swarm for multi-node consensus
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (response_text, tokens_generated, tokens_per_second)
|
Tuple of (response_text, tokens_generated, tokens_per_second)
|
||||||
"""
|
"""
|
||||||
|
# Check if federation is available
|
||||||
|
if federated_swarm is not None:
|
||||||
|
peers = federated_swarm.discovery.get_peers()
|
||||||
|
if peers:
|
||||||
|
logger.debug(f"🌐 Using federation with {len(peers)} peer(s)")
|
||||||
|
try:
|
||||||
|
content, tool_calls, finish_reason = await federated_swarm.generate_with_federation(
|
||||||
|
prompt=prompt,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature
|
||||||
|
)
|
||||||
|
# Federation returns content directly
|
||||||
|
# Note: tool_calls from federation should be ignored - head node handles tools
|
||||||
|
return content, 0, 0.0 # Tokens/TPS not tracked in federation mode
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Federation failed, falling back to local: {e}")
|
||||||
|
# Fall through to local generation
|
||||||
|
|
||||||
|
# Local generation (fallback or no federation)
|
||||||
try:
|
try:
|
||||||
result = await swarm_manager.generate(
|
result = await swarm_manager.generate(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
@@ -254,13 +277,8 @@ async def _generate_with_local_swarm(
|
|||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
use_consensus=True
|
use_consensus=True
|
||||||
)
|
)
|
||||||
|
|
||||||
response = result.selected_response
|
response = result.selected_response
|
||||||
return (
|
return response.text, response.tokens_generated, response.tokens_per_second
|
||||||
response.text,
|
|
||||||
response.tokens_generated,
|
|
||||||
response.tokens_per_second
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Error in swarm generation")
|
logger.exception("Error in swarm generation")
|
||||||
raise
|
raise
|
||||||
@@ -544,22 +562,15 @@ async def handle_chat_completion(
|
|||||||
iteration += 1
|
iteration += 1
|
||||||
logger.info(f"--- Tool Execution Iteration {iteration} ---")
|
logger.info(f"--- Tool Execution Iteration {iteration} ---")
|
||||||
|
|
||||||
# Generate response (use federation if available)
|
# Generate response (unified interface - handles federation automatically)
|
||||||
logger.debug(f"Generating response...")
|
logger.debug(f"Generating response...")
|
||||||
if use_federation and iteration == 1:
|
response_text, tokens_generated, tps = await _generate_with_consensus(
|
||||||
# First iteration: use federation for consensus
|
prompt=prompt,
|
||||||
logger.info(f"🌐 Using federation for generation...")
|
max_tokens=request.max_tokens or 1024,
|
||||||
content, tool_calls, finish_reason = await _generate_with_federation(
|
temperature=request.temperature or 0.7,
|
||||||
federated_swarm, prompt, request.max_tokens or 1024, request.temperature or 0.7
|
swarm_manager=swarm_manager,
|
||||||
)
|
federated_swarm=federated_swarm
|
||||||
response_text = content
|
)
|
||||||
tokens_generated = 0 # Will be calculated from usage if needed
|
|
||||||
tps = 0.0
|
|
||||||
else:
|
|
||||||
# Subsequent iterations or no federation: use local swarm
|
|
||||||
response_text, tokens_generated, tps = await _generate_with_local_swarm(
|
|
||||||
swarm_manager, prompt, request.max_tokens or 1024, request.temperature or 0.7
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Generated response ({len(response_text)} chars, {tokens_generated} tokens)")
|
logger.info(f"Generated response ({len(response_text)} chars, {tokens_generated} tokens)")
|
||||||
logger.debug(f"Response: {response_text[:200]}...")
|
logger.debug(f"Response: {response_text[:200]}...")
|
||||||
@@ -685,8 +696,12 @@ async def handle_chat_completion(
|
|||||||
logger.info(f" (tool_call_id: {msg.tool_call_id}, name: {msg.name})")
|
logger.info(f" (tool_call_id: {msg.tool_call_id}, name: {msg.name})")
|
||||||
logger.debug(f"Full prompt:\n{next_prompt[:1000]}...")
|
logger.debug(f"Full prompt:\n{next_prompt[:1000]}...")
|
||||||
|
|
||||||
response_text, tokens_generated, tps = await _generate_with_local_swarm(
|
response_text, tokens_generated, tps = await _generate_with_consensus(
|
||||||
swarm_manager, next_prompt, request.max_tokens or 1024, request.temperature or 0.7
|
prompt=next_prompt,
|
||||||
|
max_tokens=request.max_tokens or 1024,
|
||||||
|
temperature=request.temperature or 0.7,
|
||||||
|
swarm_manager=swarm_manager,
|
||||||
|
federated_swarm=federated_swarm
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"✅ Generated with tool results ({len(response_text)} chars, {tokens_generated} tokens)")
|
logger.info(f"✅ Generated with tool results ({len(response_text)} chars, {tokens_generated} tokens)")
|
||||||
|
|||||||
Reference in New Issue
Block a user