# Scraped file-viewer header (Spaces status: Running; file size: 6,978 bytes; revision 373e5ff).
import requests
import json
import pytest
from typing import Dict, List
import os
from datetime import datetime
import time
# Base URLs for different environments
LOCAL_URL = "http://127.0.0.1:8000"
PROD_URL = "https://theaniketgiri-synthex.hf.space"

# The API_URL environment variable selects the server under test. Fall back to
# the local server when it is unset *or empty*, and drop any trailing slash so
# that f"{BASE_URL}/health" never produces a double slash.
BASE_URL = (os.getenv("API_URL") or LOCAL_URL).rstrip("/")
def wait_for_model_loading(max_retries=10, delay=30, timeout=30):
    """Poll the ``/health`` endpoint until the server reports a loaded model.

    Args:
        max_retries: Maximum number of health checks to perform.
        delay: Seconds to sleep between two consecutive checks.
        timeout: Seconds to wait for each health response before the attempt
            is treated as failed (prevents a single request from hanging
            forever).

    Returns:
        True as soon as a health response carries a truthy ``model_loaded``
        field; False if none of the ``max_retries`` attempts does.
    """
    for attempt in range(1, max_retries + 1):
        try:
            response = requests.get(f"{BASE_URL}/health", timeout=timeout)
            data = response.json()
        except (requests.RequestException, ValueError) as e:
            # Connection errors, timeouts and non-JSON bodies are all treated
            # as "not ready yet": report and retry.
            print(f"Error checking health: {str(e)}")
        else:
            print(f"\nHealth check response: {json.dumps(data, indent=2)}")
            # Tolerate a JSON payload that is not an object instead of crashing.
            status = data if isinstance(data, dict) else {}
            if status.get("model_loaded", False):
                return True
            if status.get("model_loading", False):
                print(f"Model is still loading, attempt {attempt}/{max_retries}")
            else:
                print(f"Model not loaded yet, attempt {attempt}/{max_retries}")

        # No point in sleeping once the last attempt has been made.
        if attempt < max_retries:
            time.sleep(delay)
    return False
class TestBackendAPI:
    """End-to-end tests for the HTTP API served at ``BASE_URL``.

    The tests talk to a live server: ``GET /health`` and ``POST /generate``.
    Generation tests are skipped when the server answers 503 and fail with the
    server-provided detail when it answers 500.
    """

    # Upper bound (seconds) for any single HTTP request so a stalled server
    # cannot hang the test run indefinitely.
    REQUEST_TIMEOUT = 300

    @classmethod
    def setup_class(cls):
        """Skip the whole class unless the model finishes loading in time."""
        if not wait_for_model_loading():
            pytest.skip("Model failed to load within timeout")

    @classmethod
    def _post_generate(cls, record_type, count):
        """POST a generation request and return the raw HTTP response."""
        payload = {
            "record_type": record_type,
            "count": count,
        }
        return requests.post(
            f"{BASE_URL}/generate", json=payload, timeout=cls.REQUEST_TIMEOUT
        )

    @staticmethod
    def _error_detail(response):
        """Extract a readable error detail from a failed response.

        Falls back to the raw body text when the body is not JSON, so a
        non-JSON 500 page is reported as a test failure rather than raising
        a decoding error.
        """
        try:
            body = response.json()
        except ValueError:
            return response.text or "Unknown error"
        if isinstance(body, dict):
            return body.get("detail", "Unknown error")
        return body

    @classmethod
    def _generate_records(cls, record_type, count):
        """Request ``count`` records and return the decoded list.

        Skips the calling test on 503, fails it on 500, and asserts that the
        response is a 200 carrying a JSON list.
        """
        response = cls._post_generate(record_type, count)
        if response.status_code == 503:
            pytest.skip("Model not loaded")
        elif response.status_code == 500:
            pytest.fail(f"Generation failed: {cls._error_detail(response)}")

        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        return data

    def test_health(self):
        """Test the health check endpoint"""
        response = requests.get(
            f"{BASE_URL}/health", timeout=self.REQUEST_TIMEOUT
        )
        assert response.status_code == 200

        data = response.json()
        assert "status" in data
        assert data["status"] in ["healthy", "unhealthy"]
        assert "timestamp" in data
        assert "model_loaded" in data

        print("\n=== Health Check ===")
        print(f"Status: {data['status']}")
        print(f"Model Loaded: {data['model_loaded']}")
        print(f"Timestamp: {data['timestamp']}")

    @pytest.mark.parametrize("record_type", [
        "clinical_note",
        "discharge_summary",
        "lab_report",
        "prescription",
    ])
    def test_generate_single_record(self, record_type: str):
        """Test generating a single record of each type"""
        print(f"\n=== Generating {record_type} ===")
        data = self._generate_records(record_type, 1)
        assert len(data) == 1

        record = data[0]
        print("Generated Record:")
        print(json.dumps(record, indent=2))

        # Validate record structure
        assert "type" in record
        assert record["type"] == record_type
        assert "content" in record
        assert "generated_at" in record

    def test_generate_multiple_records(self):
        """Test generating multiple records"""
        print("\n=== Generating Multiple Records ===")
        data = self._generate_records("clinical_note", 3)
        assert len(data) == 3

        print(f"Generated {len(data)} records")
        for i, record in enumerate(data, 1):
            print(f"\nRecord {i}:")
            print(json.dumps(record, indent=2))

    def test_invalid_record_type(self):
        """Test error handling for invalid record type"""
        print("\n=== Testing Invalid Record Type ===")
        response = self._post_generate("invalid_type", 1)
        assert response.status_code == 422  # FastAPI validation error

        error = response.json()
        assert "detail" in error
        print(f"Error: {error['detail']}")

    def test_invalid_count(self):
        """Test error handling for invalid count"""
        print("\n=== Testing Invalid Count ===")
        response = self._post_generate("clinical_note", 0)
        assert response.status_code == 422  # FastAPI validation error

        error = response.json()
        assert "detail" in error
        print(f"Error: {error['detail']}")

    def test_record_content_quality(self):
        """Test the quality of generated record content"""
        print("\n=== Testing Record Content Quality ===")
        data = self._generate_records("clinical_note", 1)
        # Guard the index below so an empty list fails with a clear message.
        assert data, "No records returned"
        record = data[0]

        # Check content length
        assert len(record["content"]) > 100, "Content too short"

        # Check for common medical terms
        medical_terms = ["patient", "diagnosis", "treatment", "symptoms"]
        content_lower = record["content"].lower()
        assert any(term in content_lower for term in medical_terms), "Missing medical terminology"

        print("Content Quality Checks Passed")
        print(f"Content Length: {len(record['content'])} characters")
def main():
    """Run the test suite as a plain script, without the pytest runner.

    Mirrors what pytest would do for ``TestBackendAPI``: runs the class-level
    setup (which waits for the model to load), then every test, including
    ``test_generate_single_record`` once per record type. A ``pytest.skip``
    raised along the way is reported as a skip instead of surfacing as a raw
    traceback; assertion failures and ``pytest.fail`` still propagate.
    """
    print("Starting API Tests...")
    print(f"Testing against: {BASE_URL}")
    print("=" * 50)

    # Same record types the pytest parametrization exercises.
    record_types = (
        "clinical_note",
        "discharge_summary",
        "lab_report",
        "prescription",
    )

    try:
        TestBackendAPI.setup_class()
        test_suite = TestBackendAPI()

        # Run all tests
        test_suite.test_health()
        for record_type in record_types:
            test_suite.test_generate_single_record(record_type)
        test_suite.test_generate_multiple_records()
        test_suite.test_invalid_record_type()
        test_suite.test_invalid_count()
        test_suite.test_record_content_quality()
    except pytest.skip.Exception as skipped:
        # Outside the pytest runner nothing intercepts the skip outcome.
        print(f"\nTests skipped: {skipped}")
    else:
        print("\nAll tests completed successfully!")
    print("=" * 50)
# Allow running the checks directly as a script, without the pytest runner.
if __name__ == "__main__":
    main()