# NOTE: removed non-code web-scrape residue ("Spaces: Running on CPU Upgrade")
# that preceded the script and would have made the file invalid Python.
#!/usr/bin/env python3
"""
Test script for multi-model support
Tests model switching and generation with CodeGen and Code-Llama
"""
import requests
import time
import sys
import json

# Base URL of the locally running backend under test (see main() for the
# uvicorn command expected to serve it).
BASE_URL = "http://localhost:8000"
def print_header(text):
    """Display *text* as a section header between 60-char separator rules."""
    rule = "=" * 60
    print("\n" + rule)
    print(f" {text}")
    print(rule)
def print_result(success, message):
    """Print a PASS/FAIL line for *message* and pass *success* straight through.

    Returning the flag lets callers write ``return print_result(ok, msg)``.
    """
    if success:
        print(f"✅ PASS: {message}")
    else:
        print(f"❌ FAIL: {message}")
    return success
def test_health_check():
    """Hit /health and report whether the backend responds (True on HTTP 200)."""
    print_header("1. Health Check")
    try:
        resp = requests.get(f"{BASE_URL}/health", timeout=5)
        payload = resp.json()
        # Echo the key health fields so a failing run is easy to diagnose.
        for label, key in (("Status", "status"),
                           ("Model loaded", "model_loaded"),
                           ("Device", "device")):
            print(f"{label}: {payload.get(key)}")
        return print_result(resp.status_code == 200, "Backend is running")
    except requests.exceptions.ConnectionError:
        return print_result(False, "Cannot connect to backend. Is it running?")
    except Exception as e:
        return print_result(False, f"Health check failed: {e}")
def test_list_models():
    """Query /models, list each entry, and pass when at least two are reported."""
    print_header("2. List Available Models")
    try:
        resp = requests.get(f"{BASE_URL}/models", timeout=5)
        models = resp.json().get('models', [])
        count = len(models)
        print(f"Found {count} models:")
        for entry in models:
            mark = "✓" if entry['available'] else "✗"
            suffix = " (CURRENT)" if entry['is_current'] else ""
            print(f" {mark} {entry['name']} ({entry['size']}) - {entry['architecture']}{suffix}")
        return print_result(count >= 2, f"Found {count} models")
    except Exception as e:
        return print_result(False, f"List models failed: {e}")
def test_current_model():
    """Fetch /models/current and print the active model's summary fields."""
    print_header("3. Get Current Model Info")
    try:
        resp = requests.get(f"{BASE_URL}/models/current", timeout=5)
        info = resp.json()
        print(f"Current model: {info.get('name')}")
        print(f"Model ID: {info.get('id')}")
        # Architecture details live under a nested 'config' object.
        cfg = info.get('config', {})
        print(f"Layers: {cfg.get('num_layers')}")
        print(f"Heads: {cfg.get('num_heads')}")
        print(f"Attention: {cfg.get('attention_type')}")
        return print_result(resp.status_code == 200, "Got current model info")
    except Exception as e:
        return print_result(False, f"Get current model failed: {e}")
def test_generation(model_name, prompt="def fibonacci(n):\n ", max_tokens=30):
    """POST *prompt* to /generate and pass when at least one token comes back.

    *model_name* is used only for the header; the backend generates with
    whichever model is currently loaded.
    """
    print_header(f"4. Test Generation with {model_name}")
    print(f"Prompt: {repr(prompt)}")
    print(f"Generating {max_tokens} tokens...")
    request_body = {
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": 0.7,
        "extract_traces": False,  # Faster for testing
    }
    try:
        # Generation can take a while, hence the generous timeout.
        resp = requests.post(f"{BASE_URL}/generate", json=request_body, timeout=60)
        if resp.status_code != 200:
            return print_result(False, f"Generation failed: {resp.status_code}")
        body = resp.json()
        text = body.get('generated_text', '')
        tokens = body.get('tokens', [])
        rule = "-" * 60
        print(f"\nGenerated text:")
        print(rule)
        print(text)
        print(rule)
        print(f"Token count: {len(tokens)}")
        print(f"Confidence: {body.get('confidence', 0):.3f}")
        print(f"Perplexity: {body.get('perplexity', 0):.3f}")
        return print_result(len(tokens) > 0, f"Generated {len(tokens)} tokens")
    except Exception as e:
        return print_result(False, f"Generation failed: {e}")
def test_model_switch(model_id, model_name):
    """Ask the backend to switch to *model_id*, then verify via /models/current."""
    print_header(f"5. Switch to {model_name}")
    print(f"Switching to model: {model_id}")
    print("⏳ This may take a while (downloading + loading model)...")
    try:
        resp = requests.post(
            f"{BASE_URL}/models/switch",
            json={"model_id": model_id},
            timeout=300,  # 5 minutes for download + loading
        )
        if resp.status_code != 200:
            return print_result(False, f"Switch failed: {resp.status_code}")
        print(f"Message: {resp.json().get('message')}")
        # Don't trust the switch response alone — re-query the current model
        # and confirm the backend actually reports the requested id.
        check = requests.get(f"{BASE_URL}/models/current", timeout=5)
        ok = check.json().get('id') == model_id
        outcome = f"Switched to {model_name}" if ok else "Switch verification failed"
        return print_result(ok, outcome)
    except requests.exceptions.Timeout:
        return print_result(False, "Switch timeout - model download may be in progress")
    except Exception as e:
        return print_result(False, f"Switch failed: {e}")
def test_model_info():
    """Fetch /model/info and print detailed architecture fields (GQA-aware).

    A missing numeric field makes the ``:,`` format raise, which the blanket
    ``except`` reports as a FAIL — acceptable for a test script.
    """
    print_header("6. Get Detailed Model Info")
    try:
        resp = requests.get(f"{BASE_URL}/model/info", timeout=5)
        info = resp.json()
        print(f"Model: {info.get('name')}")
        print(f"Architecture: {info.get('architecture')}")
        print(f"Parameters: {info.get('totalParams'):,}")
        print(f"Layers: {info.get('layers')}")
        print(f"Heads: {info.get('heads')}")
        kv_heads = info.get('kv_heads')
        # kv_heads is only reported for grouped-query-attention models.
        if kv_heads:
            print(f"KV Heads: {kv_heads} (GQA)")
        print(f"Attention type: {info.get('attention_type')}")
        print(f"Vocab size: {info.get('vocabSize'):,}")
        print(f"Context length: {info.get('maxPositions'):,}")
        return print_result(resp.status_code == 200, "Got detailed model info")
    except Exception as e:
        return print_result(False, f"Get model info failed: {e}")
def main():
    """Run the whole interactive test sequence; return a shell exit code (0/1)."""
    print("\n🧪 Multi-Model Support Test Suite")
    print("This will test model switching between CodeGen 350M and Code-Llama 7B")
    print("\nIMPORTANT: Make sure the backend is running:")
    print(" cd /Users/garyboon/Development/VisualisableAI/visualisable-ai-backend")
    print(" python -m uvicorn backend.model_service:app --reload --port 8000")
    input("\nPress Enter to start tests...")

    outcomes = []

    # The backend must be reachable before anything else is attempted.
    outcomes.append(test_health_check())
    if not outcomes[-1]:
        print("\n❌ Backend not running. Exiting.")
        sys.exit(1)
    time.sleep(1)

    # Read-only endpoint checks, spaced out to be gentle on the server.
    for step in (test_list_models, test_current_model, test_model_info):
        outcomes.append(step())
        time.sleep(1)

    # Generation with the default (CodeGen) model.
    outcomes.append(test_generation("CodeGen 350M"))
    time.sleep(2)

    # The Code-Llama leg is opt-in because of the large download.
    print("\n⚠️ WARNING: Next test will download Code-Llama 7B (~14GB)")
    print("This may take 5-10 minutes depending on your internet connection.")
    if input("Proceed with Code-Llama test? (y/n): ").lower() == 'y':
        outcomes.append(test_model_switch("code-llama-7b", "Code-Llama 7B"))
        if outcomes[-1]:
            time.sleep(2)
            outcomes.append(test_model_info())
            time.sleep(1)
            outcomes.append(test_generation("Code-Llama 7B"))
            time.sleep(2)
            # Switch back and confirm the original model still generates.
            outcomes.append(test_model_switch("codegen-350m", "CodeGen 350M"))
            if outcomes[-1]:
                time.sleep(2)
                outcomes.append(test_generation("CodeGen 350M (after switch back)"))
    else:
        print("\nSkipping Code-Llama tests.")

    print_header("Test Summary")
    passed = sum(outcomes)
    total = len(outcomes)
    print(f"Passed: {passed}/{total} tests")
    if passed == total:
        print("\n🎉 All tests passed! Multi-model support is working correctly.")
        return 0
    print(f"\n⚠️ {total - passed} test(s) failed. Check output above for details.")
    return 1
# Script entry point: propagate main()'s exit code (0 on success) to the shell.
if __name__ == "__main__":
    sys.exit(main())