# Author: LogicGoInfotechSpaces
# Commit: 0ed44c0 ("Add project files")
"""
Simple test script for the Nano Banana Image Edit API
"""
import requests
import sys
import os
# Base URL of the API under test. Set the API_BASE_URL environment variable to
# target another server; the local default is used when it is unset or empty.
# A trailing slash is stripped so URLs built as f"{API_BASE_URL}/..." never
# contain "//".
API_BASE_URL = (os.environ.get("API_BASE_URL") or "http://localhost:8000").rstrip("/")
def test_upload_image(image_path, timeout=60):
    """Upload a local image file to the API's ``/upload`` endpoint.

    Args:
        image_path: Path of the image file, sent as the ``file`` form field.
        timeout: Seconds to wait for the server (passed to ``requests.post``);
            ``None`` waits indefinitely.

    Returns:
        The ``image_id`` from the JSON response on HTTP 200, or ``None`` if the
        file is missing/unreadable, the request fails, the status is not 200,
        or the response body is not the expected JSON object.
    """
    print(f"\n1. Uploading image: {image_path}")
    # isfile (not exists) so a directory path is reported here instead of
    # crashing inside open().
    if not os.path.isfile(image_path):
        print(f"Error: Image file not found: {image_path}")
        return None
    try:
        with open(image_path, "rb") as f:
            response = requests.post(
                f"{API_BASE_URL}/upload",
                files={"file": f},
                timeout=timeout,
            )
    # RequestException derives from IOError/OSError, so it must be caught
    # before the generic OSError handler for file-read problems.
    except requests.exceptions.RequestException as e:
        print(f"✗ Upload request failed: {e}")
        return None
    except OSError as e:
        print(f"✗ Could not read image file: {e}")
        return None
    if response.status_code != 200:
        print(f"✗ Upload failed: {response.status_code}")
        print(f" {response.text}")
        return None
    try:
        image_id = response.json()["image_id"]
    except (ValueError, KeyError, TypeError):
        # ValueError: body is not JSON; KeyError/TypeError: JSON lacks image_id.
        print(f"✗ Unexpected upload response: {response.text}")
        return None
    print("✓ Image uploaded successfully!")
    print(f" Image ID: {image_id}")
    return image_id
def test_edit_image(image_id, prompt, timeout=None):
    """Submit an edit request for a previously uploaded image.

    Posts ``image_id`` and ``prompt`` as form fields to ``/edit``.

    Args:
        image_id: ID returned by the upload step.
        prompt: Text instruction describing the edit.
        timeout: Seconds to wait for the response (passed to ``requests.post``).
            Defaults to ``None`` (wait indefinitely) because how long the
            server takes to answer an edit request is not known here.

    Returns:
        The ``task_id`` from the JSON response on HTTP 200, or ``None`` if the
        request fails, the status is not 200, or the body lacks ``task_id``.
    """
    print(f"\n2. Editing image with prompt: '{prompt}'")
    try:
        response = requests.post(
            f"{API_BASE_URL}/edit",
            data={"image_id": image_id, "prompt": prompt},
            timeout=timeout,
        )
    except requests.exceptions.RequestException as e:
        print(f"✗ Edit request failed: {e}")
        return None
    if response.status_code != 200:
        print(f"✗ Edit failed: {response.status_code}")
        print(f" {response.text}")
        return None
    try:
        data = response.json()
        task_id = data["task_id"]
    except (ValueError, KeyError, TypeError):
        # ValueError: body is not JSON; KeyError/TypeError: JSON lacks task_id.
        print(f"✗ Unexpected edit response: {response.text}")
        return None
    print("✓ Image edited successfully!")
    print(f" Task ID: {task_id}")
    # .get: a missing status is informational only, not a failure.
    print(f" Status: {data.get('status')}")
    return task_id
def test_get_result(task_id, timeout=60):
    """Fetch the state of an edit task from ``/result/{task_id}``.

    Args:
        task_id: ID returned by the edit step.
        timeout: Seconds to wait for the server (passed to ``requests.get``).

    Returns:
        The decoded JSON object (a dict) on HTTP 200. It carries ``status``
        and, when present, ``result_image_id`` / ``result_image_url``.
        Returns ``None`` if the request fails, the status is not 200, or the
        body is not a JSON object.
    """
    print(f"\n3. Getting result for task: {task_id}")
    try:
        response = requests.get(f"{API_BASE_URL}/result/{task_id}", timeout=timeout)
    except requests.exceptions.RequestException as e:
        print(f"✗ Get result request failed: {e}")
        return None
    if response.status_code != 200:
        print(f"✗ Get result failed: {response.status_code}")
        print(f" {response.text}")
        return None
    try:
        data = response.json()
    except ValueError:
        data = None
    # Callers use dict methods (.get) on the result, so anything else is an error.
    if not isinstance(data, dict):
        print(f"✗ Unexpected result response: {response.text}")
        return None
    print("✓ Result retrieved!")
    print(f" Status: {data.get('status')}")
    if data.get("result_image_id"):
        print(f" Result Image ID: {data['result_image_id']}")
        # The URL is optional; don't crash if only the id was returned.
        print(f" Result URL: {data.get('result_image_url')}")
    return data
def test_download_image(result_image_id, output_path, timeout=60):
    """Download an edited image from ``/result/image/{id}`` and save it.

    Args:
        result_image_id: ID of the result image reported by the result step.
        output_path: Local file path the response body is written to (binary).
        timeout: Seconds to wait for the server (passed to ``requests.get``).

    Returns:
        ``True`` if the image was fetched with HTTP 200 and written to
        ``output_path``; ``False`` on request failure, non-200 status, or a
        file-write error.
    """
    print("\n4. Downloading edited image...")
    try:
        response = requests.get(
            f"{API_BASE_URL}/result/image/{result_image_id}", timeout=timeout
        )
    except requests.exceptions.RequestException as e:
        print(f"✗ Download request failed: {e}")
        return False
    if response.status_code != 200:
        print(f"✗ Download failed: {response.status_code}")
        print(f" {response.text}")
        return False
    try:
        with open(output_path, "wb") as f:
            f.write(response.content)
    except OSError as e:
        print(f"✗ Could not write {output_path}: {e}")
        return False
    print(f"✓ Image downloaded to: {output_path}")
    return True
def test_health(timeout=10):
    """Check that the API answers on ``/health`` and report its model flags.

    Args:
        timeout: Seconds to wait for the server (passed to ``requests.get``);
            prevents hanging forever on a server that accepts but never replies.

    Returns:
        ``True`` if ``/health`` answered with HTTP 200, otherwise ``False``.
    """
    print("Testing API health...")
    try:
        response = requests.get(f"{API_BASE_URL}/health", timeout=timeout)
    except requests.exceptions.ConnectionError:
        print(f"✗ Cannot connect to API at {API_BASE_URL}")
        print(" Make sure the API server is running: python api.py")
        return False
    except requests.exceptions.RequestException as e:
        # e.g. read timeout or invalid URL: report instead of a traceback.
        print(f"✗ Health check request failed: {e}")
        return False
    if response.status_code != 200:
        print(f"✗ Health check failed: {response.status_code}")
        return False
    try:
        data = response.json()
    except ValueError:
        data = None
    if not isinstance(data, dict):
        # The server is up (HTTP 200) but did not return the expected JSON
        # object, so the model flags cannot be reported.
        print(f"✓ API is reachable, but /health returned unexpected content: {response.text}")
        return True
    print("✓ API is healthy")
    print(f" Model available: {data.get('model_available', False)}")
    print(f" Model loaded: {data.get('model_loaded', False)}")
    return True
def main(argv=None):
    """Run the health -> upload -> edit -> result -> download sequence.

    Args:
        argv: Argument list in ``sys.argv`` form (program name first, then
            ``<image_path> [prompt]``). Defaults to ``sys.argv``.

    Returns:
        Process exit code: 0 if every step succeeded, 1 otherwise.
    """
    if argv is None:
        argv = sys.argv
    # Validate arguments before touching the network, so running the script
    # without arguments shows usage even when the server is down.
    if len(argv) < 2:
        print("Usage: python test_api.py <image_path> [prompt]")
        print("Example: python test_api.py test_image.jpg 'make the sky blue'")
        return 1
    image_path = argv[1]
    prompt = argv[2] if len(argv) > 2 else "enhance the image"

    if not test_health():
        return 1
    image_id = test_upload_image(image_path)
    if not image_id:
        return 1
    task_id = test_edit_image(image_id, prompt)
    if not task_id:
        return 1
    result = test_get_result(task_id)
    if not result:
        return 1
    result_image_id = result.get("result_image_id")
    if not result_image_id:
        # Make this failure visible instead of exiting without a message.
        print(f"✗ Task has no result image (status: {result.get('status')})")
        return 1
    output_path = f"edited_{os.path.basename(image_path)}"
    # The download outcome decides success; don't claim completion on failure.
    if not test_download_image(result_image_id, output_path):
        return 1
    print("\n✓ All tests completed!")
    return 0


if __name__ == "__main__":
    sys.exit(main())