|
|
""" |
|
|
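Simple end-to-end test script for the Nano Banana Image Edit API.

Checks /health, uploads an image, requests an edit, fetches the result and
downloads the edited image. Usage: python test_api.py <image_path> [prompt]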
|
|
""" |
|
|
import requests |
|
|
import sys |
|
|
import os |
|
|
|
|
|
# Base URL of the API under test. Set the API_BASE_URL environment variable to
# target a different server. A trailing slash is stripped so that joins of the
# form f"{API_BASE_URL}/endpoint" never produce a double slash.
API_BASE_URL = os.environ.get(
    "API_BASE_URL", "http://localhost:8000"
).rstrip("/")
|
|
|
|
|
def test_upload_image(image_path, timeout=60):
    """Upload a local image file to the ``/upload`` endpoint.

    Args:
        image_path: Path of the image file on disk.
        timeout: Seconds to wait for the server, passed to ``requests`` as
            the connect/read timeout. ``None`` waits indefinitely.

    Returns:
        The ``image_id`` value from the server's JSON response on HTTP 200.
        Returns ``None`` if the file is missing, the request fails, or the
        server answers with any other status.
    """
    print(f"\n1. Uploading image: {image_path}")

    # Use isfile rather than exists: a directory path is then reported here
    # instead of crashing in open() below.
    if not os.path.isfile(image_path):
        print(f"Error: Image file not found: {image_path}")
        return None

    try:
        with open(image_path, "rb") as f:
            response = requests.post(
                f"{API_BASE_URL}/upload",
                files={"file": f},
                timeout=timeout,
            )
    except requests.exceptions.RequestException as exc:
        print(f"✗ Upload failed: {exc}")
        return None

    if response.status_code == 200:
        data = response.json()
        print("✓ Image uploaded successfully!")
        print(f" Image ID: {data['image_id']}")
        return data['image_id']

    print(f"✗ Upload failed: {response.status_code}")
    print(f" {response.text}")
    return None


# Helper driven by the __main__ block, not a pytest test. The name matches
# pytest's test_* pattern, so opt out of collection; otherwise pytest fails
# looking for an "image_path" fixture.
test_upload_image.__test__ = False
|
|
|
|
|
def test_edit_image(image_id, prompt, timeout=300):
    """Ask the ``/edit`` endpoint to edit a previously uploaded image.

    Args:
        image_id: ID returned by the upload step.
        prompt: Text instruction describing the edit.
        timeout: Seconds to wait for the server, passed to ``requests``.
            The default is generous in case the server runs the edit before
            replying (TODO confirm against the API). ``None`` waits
            indefinitely.

    Returns:
        The ``task_id`` value from the server's JSON response on HTTP 200,
        otherwise ``None``.
    """
    print(f"\n2. Editing image with prompt: '{prompt}'")

    try:
        response = requests.post(
            f"{API_BASE_URL}/edit",
            data={"image_id": image_id, "prompt": prompt},
            timeout=timeout,
        )
    except requests.exceptions.RequestException as exc:
        print(f"✗ Edit failed: {exc}")
        return None

    if response.status_code == 200:
        data = response.json()
        print("✓ Image edited successfully!")
        print(f" Task ID: {data['task_id']}")
        # status is only printed, so a missing key should not abort the run.
        print(f" Status: {data.get('status')}")
        return data['task_id']

    print(f"✗ Edit failed: {response.status_code}")
    print(f" {response.text}")
    return None


# Helper driven by the __main__ block, not a pytest test. Opt out of pytest
# collection, which would otherwise fail on missing fixtures.
test_edit_image.__test__ = False
|
|
|
|
|
def test_get_result(task_id, timeout=30):
    """Fetch the state of an edit task from ``/result/<task_id>``.

    Args:
        task_id: ID returned by the edit step.
        timeout: Seconds to wait for the server, passed to ``requests``.
            ``None`` waits indefinitely.

    Returns:
        The decoded JSON response on HTTP 200, even when it has no
        ``result_image_id`` yet. Returns ``None`` on a request error or any
        other status.
    """
    print(f"\n3. Getting result for task: {task_id}")

    try:
        response = requests.get(
            f"{API_BASE_URL}/result/{task_id}", timeout=timeout
        )
    except requests.exceptions.RequestException as exc:
        print(f"✗ Get result failed: {exc}")
        return None

    if response.status_code != 200:
        print(f"✗ Get result failed: {response.status_code}")
        print(f" {response.text}")
        return None

    data = response.json()
    print("✓ Result retrieved!")
    print(f" Status: {data.get('status')}")
    if data.get('result_image_id'):
        print(f" Result Image ID: {data['result_image_id']}")
        # The URL is informational only; don't crash if the server omits it.
        if data.get('result_image_url'):
            print(f" Result URL: {data['result_image_url']}")
    return data


# Helper driven by the __main__ block, not a pytest test. Opt out of pytest
# collection, which would otherwise fail on a missing "task_id" fixture.
test_get_result.__test__ = False
|
|
|
|
|
def test_download_image(result_image_id, output_path, timeout=60):
    """Download an edited image from ``/result/image/<id>`` to a local file.

    Args:
        result_image_id: ID of the edited image reported by the result step.
        output_path: Local path the response body is written to (overwritten
            if it exists).
        timeout: Seconds to wait for the server, passed to ``requests``.
            ``None`` waits indefinitely.

    Returns:
        ``True`` if the image was fetched with HTTP 200 and written to
        ``output_path``, otherwise ``False``.
    """
    print("\n4. Downloading edited image...")

    try:
        response = requests.get(
            f"{API_BASE_URL}/result/image/{result_image_id}", timeout=timeout
        )
    except requests.exceptions.RequestException as exc:
        print(f"✗ Download failed: {exc}")
        return False

    if response.status_code != 200:
        print(f"✗ Download failed: {response.status_code}")
        print(f" {response.text}")
        return False

    # Report a local write problem (bad directory, permissions) as a failed
    # step rather than letting it escape as a traceback.
    try:
        with open(output_path, "wb") as f:
            f.write(response.content)
    except OSError as exc:
        print(f"✗ Could not write {output_path}: {exc}")
        return False

    print(f"✓ Image downloaded to: {output_path}")
    return True


# Helper driven by the __main__ block, not a pytest test. Opt out of pytest
# collection, which would otherwise fail on missing fixtures.
test_download_image.__test__ = False
|
|
|
|
|
def test_health(timeout=10):
    """Check that the API server answers on ``/health`` and print model status.

    Args:
        timeout: Seconds to wait for the server, passed to ``requests``.
            ``None`` waits indefinitely.

    Returns:
        ``True`` if ``/health`` responds with HTTP 200. ``False`` if the
        server is unreachable, the request times out or fails, or it returns
        any other status.
    """
    print("Testing API health...")

    try:
        response = requests.get(f"{API_BASE_URL}/health", timeout=timeout)
    except requests.exceptions.ConnectionError:
        print(f"✗ Cannot connect to API at {API_BASE_URL}")
        print(" Make sure the API server is running: python api.py")
        return False
    except requests.exceptions.RequestException as exc:
        # For example a read Timeout: report it as a failed check instead of
        # letting it escape as a traceback.
        print(f"✗ Health check failed: {exc}")
        return False

    if response.status_code != 200:
        print(f"✗ Health check failed: {response.status_code}")
        return False

    data = response.json()
    print("✓ API is healthy")
    print(f" Model available: {data.get('model_available', False)}")
    print(f" Model loaded: {data.get('model_loaded', False)}")
    return True


# Helper driven by the __main__ block, not a pytest test. Without this, pytest
# would collect and run it (it takes no arguments) against whatever server is
# or isn't running.
test_health.__test__ = False
|
|
|
|
|
def main(argv=None):
    """Run the health -> upload -> edit -> result -> download smoke test.

    Args:
        argv: Argument list including the program name, laid out as
            ``[prog, image_path, prompt?]``. Defaults to ``sys.argv``.

    Returns:
        Process exit code: 0 if every step succeeded, 1 otherwise.
    """
    argv = sys.argv if argv is None else argv

    # Validate arguments before touching the network, so a usage error is
    # reported even when the server is down.
    if len(argv) < 2:
        print("Usage: python test_api.py <image_path> [prompt]")
        print("Example: python test_api.py test_image.jpg 'make the sky blue'")
        return 1
    image_path = argv[1]
    prompt = argv[2] if len(argv) > 2 else "enhance the image"

    if not test_health():
        return 1

    image_id = test_upload_image(image_path)
    if not image_id:
        return 1

    task_id = test_edit_image(image_id, prompt)
    if not task_id:
        return 1

    result = test_get_result(task_id)
    if not result:
        return 1
    result_image_id = result.get('result_image_id')
    if not result_image_id:
        # Say why the run stops instead of exiting silently.
        print(f"✗ No result image available (status: {result.get('status')})")
        return 1

    output_path = f"edited_{os.path.basename(image_path)}"
    # A failed download must fail the run, not be reported as success.
    if not test_download_image(result_image_id, output_path):
        return 1

    print("\n✓ All tests completed!")
    return 0


if __name__ == "__main__":
    sys.exit(main())
|
|
|
|
|
|