"""Test Apple Silicon MLX auto-detection and download.""" import sys import os from pathlib import Path # Add src to path sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) def test_apple_silicon_mlx_selection(): """Test that Apple Silicon correctly selects MLX models.""" from hardware.detector import HardwareProfile, GPUInfo from models.selector import select_optimal_model # Mock Apple Silicon hardware class MockAppleHardware: os = "darwin" cpu_cores = 12 ram_gb = 24.0 ram_available_gb = 12.0 is_apple_silicon = True has_dedicated_gpu = False gpu = GPUInfo(name="Apple Silicon GPU", vram_gb=24.0, driver_version=None) available_memory_gb = 12.0 recommended_memory_gb = 12.0 hardware = MockAppleHardware() # Test auto-detection (use_mlx=None) print("=" * 60) print("Apple Silicon MLX Auto-Detection Test") print("=" * 60) print("\n1. Testing auto-detection (use_mlx=None)...") config = select_optimal_model(hardware, use_mlx=None) assert config is not None, "Should find a model" print(f" ✓ Model selected: {config.model.name}") # Verify quantization is MLX format (4bit, 8bit, etc.) print("\n2. Verifying MLX quantization format...") is_mlx_format = 'bit' in config.quantization.name.lower() assert is_mlx_format, f"Quantization should be MLX format (4bit/8bit), got {config.quantization.name}" print(f" ✓ Quantization: {config.quantization.name} (MLX format)") # Test repository name generation print("\n3. Testing MLX repository name generation...") from models.registry import get_model_hf_repo_mlx mlx_repo = get_model_hf_repo_mlx(config.model.id, config.variant, config.quantization) assert mlx_repo is not None, "MLX repository should be generated" assert "mlx-community" in mlx_repo, "Should use mlx-community namespace" assert "-Instruct-" in mlx_repo, "Should have -Instruct- suffix" assert config.quantization.name in mlx_repo, "Should include quantization" print(f" ✓ Repository: {mlx_repo}") # Verify it's NOT using GGUF format print("\n4. Verifying NOT using GGUF format...") has_gguf = 'q4_k_m' in config.quantization.name or 'q5_k_m' in config.quantization.name has_gguf_suffix = '-GGUF' in mlx_repo assert not has_gguf, f"Should not use GGUF quantization names" assert not has_gguf_suffix, f"Should not use GGUF repository suffix" print(f" ✓ Not using GGUF format") print("\n" + "=" * 60) print("All Apple Silicon MLX tests passed!") print("=" * 60) def test_nvidia_gpu_gguf_selection(): """Test that NVIDIA GPU correctly selects GGUF models.""" from hardware.detector import HardwareProfile, GPUInfo from models.selector import select_optimal_model # Mock NVIDIA hardware class MockNvidiaHardware: os = "linux" cpu_cores = 8 ram_gb = 32.0 ram_available_gb = 20.0 is_apple_silicon = False has_dedicated_gpu = True gpu = GPUInfo(name="NVIDIA RTX 4090", vram_gb=24.0, driver_version="550.80") available_memory_gb = 20.0 recommended_memory_gb = 20.0 hardware = MockNvidiaHardware() print("\n" + "=" * 60) print("NVIDIA GPU GGUF Auto-Detection Test") print("=" * 60) print("\n1. Testing auto-detection (use_mlx=None)...") config = select_optimal_model(hardware, use_mlx=None) assert config is not None, "Should find a model" print(f" ✓ Model selected: {config.model.name}") # Verify quantization is GGUF format (q4_k_m, q5_k_m, etc.) print("\n2. Verifying GGUF quantization format...") is_gguf_format = 'q' in config.quantization.name.lower() assert is_gguf_format, f"Quantization should be GGUF format (q4_k_m/q5_k_m), got {config.quantization.name}" print(f" ✓ Quantization: {config.quantization.name} (GGUF format)") # Test repository name generation print("\n3. Testing GGUF repository name generation...") from models.registry import get_model_hf_repo gguf_repo = get_model_hf_repo(config.model.id, config.variant, config.quantization) assert gguf_repo is not None, "GGUF repository should be generated" assert "-GGUF" in gguf_repo, "Should have -GGUF suffix" print(f" ✓ Repository: {gguf_repo}") # Verify it's NOT using MLX format print("\n4. Verifying NOT using MLX format...") has_mlx_format = 'bit' in config.quantization.name.lower() and config.quantization.name not in ['q4_k_m', 'q5_k_m', 'q6_k'] has_mlx_namespace = 'mlx-community' in gguf_repo assert not has_mlx_namespace, f"Should not use mlx-community namespace" print(f" ✓ Not using MLX format") print("\n" + "=" * 60) print("All NVIDIA GPU GGUF tests passed!") print("=" * 60) if __name__ == "__main__": try: test_apple_silicon_mlx_selection() test_nvidia_gpu_gguf_selection() print("\n" + "=" * 60) print("ALL AUTO-DETECTION TESTS PASSED!") print("=" * 60) except AssertionError as e: print(f"\n❌ Test failed: {e}") sys.exit(1) except Exception as e: print(f"\n❌ Test error: {e}") import traceback traceback.print_exc() sys.exit(1)