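# NOTE: the original capture of this file began with the text
# "Spaces: Runtime error / Runtime error" -- apparently a status banner from
# the hosting page rather than part of the script.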
#!/usr/bin/env python3
"""
Test script for the REST API.
"""
import json
import os
import sys
import tempfile
import time

import numpy as np
import requests
import soundfile as sf
def test_api_health(base_url, timeout=10):
    """Check that the health endpoint answers with HTTP 200.

    Sends ``GET {base_url}/api/health`` and prints the decoded JSON body on
    success, or the status code / error message on failure.

    Args:
        base_url: Root URL of the API; ``/api/health`` is appended as-is.
        timeout: Seconds to wait for the server. ``requests`` has no default
            timeout, so without one an unresponsive host would block the
            whole test run.

    Returns:
        True if the endpoint returned 200 and its body decoded as JSON,
        False otherwise (including connection errors and timeouts).
    """
    print("🔍 Test de l'endpoint health...")
    try:
        response = requests.get(f"{base_url}/api/health", timeout=timeout)
        if response.status_code != 200:
            print(f"❌ Health check échoué: {response.status_code}")
            return False
        data = response.json()
        print(f"✅ Health check réussi: {data}")
        return True
    except Exception as e:
        # Deliberately broad: this function is a reporting boundary, so any
        # failure (network error, timeout, invalid JSON) is printed and
        # reported as a failed check instead of aborting the run.
        print(f"❌ Erreur health check: {e}")
        return False
def test_api_info(base_url, timeout=10):
    """Check that the API root endpoint answers with HTTP 200.

    Sends ``GET {base_url}/api/`` and prints the decoded JSON body on
    success, or the status code / error message on failure.

    Args:
        base_url: Root URL of the API; ``/api/`` is appended as-is.
        timeout: Seconds to wait for the server. ``requests`` has no default
            timeout, so without one an unresponsive host would block the
            whole test run.

    Returns:
        True if the endpoint returned 200 and its body decoded as JSON,
        False otherwise (including connection errors and timeouts).
    """
    print("🔍 Test de l'endpoint racine...")
    try:
        response = requests.get(f"{base_url}/api/", timeout=timeout)
        if response.status_code != 200:
            print(f"❌ Info API échoué: {response.status_code}")
            return False
        data = response.json()
        print(f"✅ Info API récupérée: {data}")
        return True
    except Exception as e:
        # Deliberately broad: any failure (network error, timeout, invalid
        # JSON) is printed and reported as a failed check.
        print(f"❌ Erreur info API: {e}")
        return False
def create_test_audio(path="test_audio_api.wav", sample_rate=16000, duration=1.0):
    """Write a short synthetic WAV file to use as an upload fixture.

    The signal is the sum of two sine tones: 440 Hz at amplitude 0.1 and
    880 Hz at amplitude 0.05. It is a placeholder signal, not real speech.

    Args:
        path: Destination file path. Defaults to ``test_audio_api.wav`` in
            the current working directory.
        sample_rate: Sampling rate in Hz passed to ``soundfile.write``.
        duration: Length of the signal in seconds.

    Returns:
        The path the file was written to.
    """
    print("🎵 Création d'un fichier audio de test...")

    num_samples = int(sample_rate * duration)
    # endpoint=False keeps the sample spacing at exactly 1/sample_rate.
    # Including the endpoint would spread the samples over duration/(N-1)
    # and slightly detune the tones once written at ``sample_rate``.
    t = np.linspace(0, duration, num_samples, endpoint=False)

    fundamental = 0.1 * np.sin(2 * np.pi * 440 * t)
    octave = 0.05 * np.sin(2 * np.pi * 880 * t)
    audio = fundamental + octave

    sf.write(path, audio, sample_rate)
    print(f"✅ Fichier audio de test créé: {path}")
    return path
def test_audio_prediction(base_url, audio_path, timeout=120):
    """Upload an audio file to the predict endpoint and report the result.

    Sends ``POST {base_url}/api/predict`` with the file as the multipart
    field ``file``. On HTTP 200 the ``transcription`` and ``sentiment`` keys
    of the JSON response are printed (``N/A`` when a key is absent).

    Args:
        base_url: Root URL of the API; ``/api/predict`` is appended as-is.
        audio_path: Path of the audio file to upload.
        timeout: Seconds to wait for the server. The default is generous
            because the request may involve model inference, but it is
            bounded so a hung server cannot block the run.

    Returns:
        True if the endpoint returned 200 with a JSON body, False otherwise
        (including an unreadable file, connection errors and timeouts).
    """
    print("🔍 Test de l'endpoint predict (audio)...")
    try:
        # The upload must happen while the file handle is still open.
        with open(audio_path, 'rb') as f:
            response = requests.post(
                f"{base_url}/api/predict",
                files={'file': f},
                timeout=timeout,
            )

        if response.status_code != 200:
            print(f"❌ Prédiction audio échouée: {response.status_code}")
            print(f" Erreur: {response.text}")
            return False

        data = response.json()
        print("✅ Prédiction audio réussie:")
        print(f" Transcription: {data.get('transcription', 'N/A')}")
        print(f" Sentiment: {data.get('sentiment', 'N/A')}")
        return True
    except Exception as e:
        # Deliberately broad: any failure is printed and reported as a
        # failed check rather than aborting the run.
        print(f"❌ Erreur prédiction audio: {e}")
        return False
def test_text_prediction(base_url, timeout=60):
    """Send a few sample sentences to the text prediction endpoint.

    For each sample, sends ``POST {base_url}/api/predict_text`` with the JSON
    body ``{"text": ...}`` and prints the ``sentiment`` key of the response
    (``N/A`` when absent). Stops at the first failing sample.

    Args:
        base_url: Root URL of the API; ``/api/predict_text`` is appended
            as-is.
        timeout: Seconds to wait for the server on each request, so a hung
            server cannot block the run.

    Returns:
        True if every sample returned HTTP 200 with a JSON body, False as
        soon as one request fails or raises.
    """
    print("🔍 Test de l'endpoint predict_text...")

    test_texts = [
        "je suis très content de ce produit",
        "ce service est terrible",
        "c'est neutre comme commentaire",
    ]

    for text in test_texts:
        try:
            response = requests.post(
                f"{base_url}/api/predict_text",
                json={"text": text},
                timeout=timeout,
            )
            if response.status_code != 200:
                print(f"❌ Prédiction textuelle échouée pour '{text}': {response.status_code}")
                return False
            result = response.json()
            print(f"✅ Prédiction textuelle réussie pour '{text}':")
            print(f" Sentiment: {result.get('sentiment', 'N/A')}")
        except Exception as e:
            # Deliberately broad: any failure is printed and reported as a
            # failed check rather than aborting the run.
            print(f"❌ Erreur prédiction textuelle: {e}")
            return False

    return True
def test_error_handling(base_url, timeout=60):
    """Check how the API responds to invalid input.

    Two checks are run in order; the function stops at the first failure:

    1. A ``.txt`` file containing plain text is uploaded to
       ``POST {base_url}/api/predict``; the expected status is 400.
    2. An empty string is sent to ``POST {base_url}/api/predict_text``;
       either 200 or 400 is accepted.

    Args:
        base_url: Root URL of the API; endpoint paths are appended as-is.
        timeout: Seconds to wait for the server on each request, so a hung
            server cannot block the run.

    Returns:
        True if both checks got an accepted status code, False otherwise
        (including connection errors and timeouts).
    """
    print("🔍 Test de la gestion d'erreurs...")

    # Check 1: a file that is not audio.
    try:
        # The temporary file is uploaded through the handle that wrote it and
        # is deleted when the ``with`` block exits, so nothing is left on
        # disk and the file is never opened a second time by name.
        with tempfile.NamedTemporaryFile(suffix='.txt') as f:
            f.write(b"Ceci n'est pas un fichier audio")
            f.flush()
            f.seek(0)  # rewind so the upload reads the content just written
            response = requests.post(
                f"{base_url}/api/predict",
                files={'file': f},
                timeout=timeout,
            )

        if response.status_code == 400:
            print("✅ Gestion d'erreur fichier invalide: OK")
        else:
            print(f"❌ Gestion d'erreur fichier invalide: {response.status_code}")
            return False
    except Exception as e:
        # Deliberately broad: any failure is printed and reported as a
        # failed check rather than aborting the run.
        print(f"❌ Erreur test fichier invalide: {e}")
        return False

    # Check 2: empty text.
    try:
        response = requests.post(
            f"{base_url}/api/predict_text",
            json={"text": ""},
            timeout=timeout,
        )
        if response.status_code in (200, 400):
            print("✅ Gestion d'erreur texte vide: OK")
        else:
            print(f"❌ Gestion d'erreur texte vide: {response.status_code}")
            return False
    except Exception as e:
        print(f"❌ Erreur test texte vide: {e}")
        return False

    return True
def test_documentation(base_url, timeout=10):
    """Check that the Swagger documentation page is reachable.

    Sends ``GET {base_url}/api/docs`` and only inspects the status code.

    Args:
        base_url: Root URL of the API; ``/api/docs`` is appended as-is.
        timeout: Seconds to wait for the server. ``requests`` has no default
            timeout, so without one an unresponsive host would block the
            whole test run.

    Returns:
        True if the page returned HTTP 200, False otherwise (including
        connection errors and timeouts).
    """
    print("🔍 Test de la documentation Swagger...")
    try:
        response = requests.get(f"{base_url}/api/docs", timeout=timeout)
        if response.status_code != 200:
            print(f"❌ Documentation Swagger inaccessible: {response.status_code}")
            return False
        print("✅ Documentation Swagger accessible")
        return True
    except Exception as e:
        # Deliberately broad: any failure is printed and reported as a
        # failed check rather than aborting the run.
        print(f"❌ Erreur documentation Swagger: {e}")
        return False
def _run_audio_prediction(base_url):
    """Create the audio fixture, run the audio prediction test, then clean up."""
    audio_path = create_test_audio()
    try:
        return test_audio_prediction(base_url, audio_path)
    finally:
        # The fixture is only needed for this one test; do not leave it
        # behind. Cleanup is best-effort and must not mask the test result.
        try:
            os.remove(audio_path)
        except OSError:
            pass


def main(base_url="http://localhost:7860"):
    """Run every API test, print a summary and report overall success.

    Each test runs inside its own ``try`` so that one unexpected exception is
    recorded as a failure and the remaining tests still run. The audio
    fixture is created inside the audio test itself, so a failure to generate
    it is reported as a failed test instead of aborting the whole run.

    Args:
        base_url: Root URL of the API under test. Defaults to the local
            server; pass the deployment URL to test a hosted instance (the
            original file suggested a Hugging Face Spaces URL of the form
            ``https://huggingface.co/spaces/<username>/sentiment-audio-analyzer``).

    Returns:
        True if every test passed, False otherwise.
    """
    print("🚀 Démarrage des tests de l'API...\n")

    tests = [
        ("Health check", lambda: test_api_health(base_url)),
        ("Info API", lambda: test_api_info(base_url)),
        ("Documentation Swagger", lambda: test_documentation(base_url)),
        ("Gestion d'erreurs", lambda: test_error_handling(base_url)),
        ("Prédiction audio", lambda: _run_audio_prediction(base_url)),
        ("Prédiction textuelle", lambda: test_text_prediction(base_url)),
    ]

    separator = '=' * 50
    results = []
    for test_name, test_func in tests:
        print(f"\n{separator}")
        print(f"Test: {test_name}")
        print(separator)
        try:
            outcome = bool(test_func())
        except Exception as e:
            # Last-resort guard: a crashing test counts as a failure.
            print(f"❌ Erreur inattendue: {e}")
            outcome = False
        results.append((test_name, outcome))

    # Summary
    print(f"\n{separator}")
    print("📊 RÉSUMÉ DES TESTS API")
    print(separator)

    for test_name, outcome in results:
        status = "✅ PASS" if outcome else "❌ FAIL"
        print(f"{test_name}: {status}")

    total = len(results)
    passed = sum(1 for _, outcome in results if outcome)
    print(f"\nRésultat: {passed}/{total} tests réussis")

    if passed == total:
        print("🎉 Tous les tests API sont passés !")
        return True

    print("⚠️ Certains tests API ont échoué.")
    return False
if __name__ == "__main__":
    # Use sys.exit rather than the ``exit`` helper: the latter is injected by
    # the ``site`` module for interactive use and is missing when Python is
    # started with -S. Exit code 0 means every test passed, 1 otherwise.
    success = main()
    sys.exit(0 if success else 1)