Spaces:
Runtime error
Runtime error
import os | |
import openai | |
import gradio as gr | |
import base64 | |
from openai import OpenAI | |
from PIL import Image # Import Pillow | |
import io # Import io for byte streams | |
# --- Configuration ---
# The key is looked up here only so a missing key can be reported at startup;
# when it is absent, the bare OpenAI() call below falls back to the SDK's own
# lookup of the same environment variable.
api_key = os.getenv("OPENAI_API_KEY")
if api_key is None:
    # Without a key the client construction below raises, leaving `client`
    # as None, so image generation will report "client not initialized".
    print(
        "Warning: OPENAI_API_KEY environment variable not set. "
        "The OpenAI client cannot be initialized, so image generation will be unavailable."
    )

IMAGE_MODEL = "gpt-image-1"
# IMAGE_MODEL = "dall-e-3"  # Alternative model; generate_image accepts both URL and base64 responses.

# Initialize the OpenAI client. On failure `client` stays None, and
# generate_image checks for that before calling the API.
client = None
try:
    if api_key:
        client = OpenAI(api_key=api_key)
    else:
        client = OpenAI()
except Exception as e:
    # Startup boundary: report the problem but keep the UI running so the
    # status box can explain why generation is unavailable.
    print(f"Error initializing OpenAI client: {e}")
# --- Image Generation Function ---

# Seconds to wait when downloading a generated image from a URL response.
_DOWNLOAD_TIMEOUT_SECONDS = 60


def _image_from_bytes(image_bytes):
    """Decode raw encoded image bytes into a fully loaded PIL image."""
    pil_image = Image.open(io.BytesIO(image_bytes))
    # Image.open only reads the header; force the full decode now so corrupt
    # data raises inside the caller's error handling instead of later in Gradio.
    pil_image.load()
    return pil_image


def _image_from_url(image_url):
    """Download *image_url* and return ``(PIL image or None, status message)``."""
    try:
        import requests  # Optional dependency, only needed for URL responses.

        # A timeout keeps a stalled download from hanging the request forever.
        img_response = requests.get(image_url, timeout=_DOWNLOAD_TIMEOUT_SECONDS)
        img_response.raise_for_status()  # Raise on 4xx/5xx responses.
        pil_image = _image_from_bytes(img_response.content)
        return pil_image, "Image generated successfully (from URL)."
    except Exception as download_err:
        print(f"Failed to download image from URL {image_url}: {download_err}")
        return None, f"Failed to download image from URL: {download_err}"


def generate_image(prompt_text):
    """Generate one 1024x1024 image for *prompt_text* using ``IMAGE_MODEL``.

    Args:
        prompt_text: The text prompt from the UI textbox.

    Returns:
        A ``(image, status)`` pair, one value per Gradio output component.
        ``image`` is a ``PIL.Image.Image`` on success or ``None`` on failure.
        ``status`` is a human-readable message. Errors are reported through
        ``status`` rather than raised, so the UI never shows a traceback.
    """
    if client is None:
        return None, "Error: OpenAI client not initialized. Check API Key or initialization."
    # Treat a whitespace-only prompt the same as an empty one.
    if not prompt_text or not prompt_text.strip():
        return None, "Error: Please enter a prompt."

    print(f"Generating image for prompt: '{prompt_text}' using model {IMAGE_MODEL}")
    try:
        response = client.images.generate(
            model=IMAGE_MODEL,
            prompt=prompt_text,
            n=1,
            size="1024x1024",
        )
        if not response.data:
            print("Error: Image API returned no image data.")
            return None, "Error: Received unexpected response format from the image API."
        result = response.data[0]

        # The API may return a hosted URL or an inline base64 payload. The
        # output component (type="pil") needs a PIL image in either case.
        image_url = getattr(result, "url", None)
        if image_url:
            print(f"Image generated successfully (URL): {image_url}")
            return _image_from_url(image_url)

        b64_data = getattr(result, "b64_json", None)
        if b64_data:
            print("Image generated successfully (base64 received).")
            pil_image = _image_from_bytes(base64.b64decode(b64_data))
            return pil_image, "Image generated successfully (from base64)."

        print("Error: Unexpected response format from API.")
        print(result)
        return None, "Error: Received unexpected response format from the image API."
    except openai.AuthenticationError as e:
        error_message = f"Authentication Error: Check your API key. Details: {str(e)}"
        print(error_message)
        return None, error_message
    except openai.RateLimitError as e:
        error_message = f"Rate Limit Error: You might have exceeded your quota. Details: {str(e)}"
        print(error_message)
        return None, error_message
    except openai.BadRequestError as e:
        error_message = f"Bad Request Error: {str(e)}"
        print(error_message)
        return None, error_message
    except Exception as e:
        # UI boundary: surface any other failure (decode errors, network
        # errors, etc.) in the status box instead of crashing the handler.
        error_message = f"An unexpected error occurred: {str(e)}"
        print(error_message)
        return None, error_message
# --- Gradio Interface ---
_app_title = f"AI Image Generator ({IMAGE_MODEL})"

with gr.Blocks(title=_app_title) as demo:
    # Page header.
    gr.Markdown(f"# {_app_title}")
    gr.Markdown(f"Generate AI-generated images using the {IMAGE_MODEL} model.")

    with gr.Row():
        # Left column: prompt entry and the action buttons.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(label="Enter your image prompt", lines=3)
            with gr.Row():
                clear_button = gr.Button("Clear")
                submit_button = gr.Button("Generate Image", variant="primary")
        # Right column: result image plus a read-only status line. type="pil"
        # lets the component accept the PIL images generate_image returns.
        with gr.Column(scale=2):
            output_image = gr.Image(label="Generated Image", type="pil")
            status_text = gr.Textbox(label="Status", value="", interactive=False)

    # generate_image returns (image, status): one value per output component.
    submit_button.click(
        fn=generate_image,
        inputs=[prompt_input],
        outputs=[output_image, status_text],
    )
    # Reset the prompt to "", drop the image, and blank the status message.
    clear_button.click(
        fn=lambda: ["", None, ""],
        inputs=None,
        outputs=[prompt_input, output_image, status_text],
    )

    gr.HTML('<br><p style="text-align:center;">Interface by Gradio</p>')
# --- Launch the App ---
if __name__ == "__main__":
    # Pillow and io are imported unconditionally at the top of this module, so a
    # missing Pillow install fails at import time and needs no check here.
    # `requests` is optional: generate_image imports it lazily, only when the
    # API answers with an image URL.
    try:
        import requests  # noqa: F401 -- availability check only
    except ImportError:
        print("Warning: 'requests' library not found. Won't be able to display images from URLs if the API returns a URL.")

    print("Launching Gradio Interface...")
    # NOTE(review): mcp_server=True presumably needs Gradio's MCP extra
    # (gradio[mcp]) to be installed; confirm it is in the Space's requirements.
    demo.launch(mcp_server=True)