# gpt-image-mcp / app.py
# Last commit: e472f1a "Update app.py" by DavidHoa (verified)
import os
import openai
import gradio as gr
import base64
from openai import OpenAI
from PIL import Image # Import Pillow
import io # Import io for byte streams
# --- Configuration ---
# The key is looked up here only so a missing key can be reported early; the
# client below is built with or without it depending on whether it is set.
api_key = os.getenv("OPENAI_API_KEY")
if api_key is None:
    print("Warning: OPENAI_API_KEY environment variable not set. Using default behavior.")

IMAGE_MODEL = "gpt-image-1"
# IMAGE_MODEL = "dall-e-3"  # Use this if 'gpt-image-1' is incorrect

# Build the OpenAI client. If construction raises, `client` stays None and
# the error is printed; generate_image() refuses to run while it is None.
client = None
try:
    _client_kwargs = {"api_key": api_key} if api_key else {}
    client = OpenAI(**_client_kwargs)
except Exception as e:
    print(f"Error initializing OpenAI client: {e}")
# --- Image Generation Function ---
# Upper bound (seconds) for fetching a generated image when the API answers
# with a URL instead of inline base64 data. Without a timeout, requests.get
# can block the Gradio worker indefinitely on a stalled connection.
_DOWNLOAD_TIMEOUT_SECONDS = 60


def _report_error(message):
    """Log ``message`` to stdout and return it in the handler's (image, status) shape."""
    print(message)
    return None, message


def _open_image(image_bytes):
    """Decode raw image bytes into a fully loaded PIL image.

    ``Image.open`` only reads the header lazily; ``load()`` forces a full
    decode here so corrupt or truncated data raises inside generate_image's
    error handling instead of later, when Gradio post-processes the output.
    """
    pil_image = Image.open(io.BytesIO(image_bytes))
    pil_image.load()
    return pil_image


def _download_image(url):
    """Download the image at ``url`` and return it as a PIL image.

    Raises ImportError if ``requests`` is not installed, ``requests``
    exceptions for network/HTTP failures, or Pillow errors for undecodable
    data; the caller converts any of these into a status message.
    """
    import requests  # Optional dependency: only needed when the API returns URLs.

    img_response = requests.get(url, timeout=_DOWNLOAD_TIMEOUT_SECONDS)
    img_response.raise_for_status()  # Turn 4xx/5xx responses into an exception.
    return _open_image(img_response.content)


def generate_image(prompt_text: str):
    """Generate an image from a text prompt with the configured OpenAI image model.

    Args:
        prompt_text: Text description of the image to generate. Must contain
            at least one non-whitespace character.

    Returns:
        A ``(image, status)`` tuple. ``image`` is a PIL image on success and
        None on failure; ``status`` is a human-readable message describing
        the outcome. Failures are reported through ``status``, not raised.
    """
    if not client:
        return None, "Error: OpenAI client not initialized. Check API Key or initialization."
    # Reject empty and whitespace-only prompts before spending an API call.
    if not prompt_text or not prompt_text.strip():
        return None, "Error: Please enter a prompt."

    print(f"Generating image for prompt: '{prompt_text}' using model {IMAGE_MODEL}")
    try:
        response = client.images.generate(
            model=IMAGE_MODEL,
            prompt=prompt_text,
            n=1,
            size="1024x1024",
        )
        # Guard against an empty/missing result list instead of letting
        # response.data[0] fail with an opaque IndexError/TypeError.
        if not response.data:
            print("Error: Image API returned no image data.")
            return None, "Error: Received unexpected response format from the image API."

        image_data = response.data[0]
        # The response item may carry either a hosted URL or inline base64
        # data, so both forms are handled; the URL form takes precedence.
        image_url = getattr(image_data, "url", None)
        b64_data = getattr(image_data, "b64_json", None)

        if image_url:
            print(f"Image generated successfully (URL): {image_url}")
            try:
                pil_image = _download_image(image_url)
            except Exception as download_err:
                print(f"Failed to download image from URL {image_url}: {download_err}")
                return None, f"Failed to download image from URL: {download_err}"
            return pil_image, "Image generated successfully (from URL)."

        if b64_data:
            print("Image generated successfully (base64 received).")
            pil_image = _open_image(base64.b64decode(b64_data))
            return pil_image, "Image generated successfully (from base64)."

        print("Error: Unexpected response format from API.")
        print(image_data)
        return None, "Error: Received unexpected response format from the image API."
    except openai.AuthenticationError as e:
        return _report_error(f"Authentication Error: Check your API key. Details: {e}")
    except openai.RateLimitError as e:
        return _report_error(f"Rate Limit Error: You might have exceeded your quota. Details: {e}")
    except openai.BadRequestError as e:
        return _report_error(f"Bad Request Error: {e}")
    except Exception as e:
        # Top-level boundary of a UI handler: surface any other failure
        # (decode errors, network errors, SDK errors) as a status message.
        return _report_error(f"An unexpected error occurred: {e}")
# --- Gradio Interface ---
# Build the page title once; it is shown both as the browser tab title and
# as the top-level heading.
_app_title = f"AI Image Generator ({IMAGE_MODEL})"

with gr.Blocks(title=_app_title) as demo:
    gr.Markdown("# " + _app_title)
    gr.Markdown(f"Generate AI-generated images using the {IMAGE_MODEL} model.")

    with gr.Row():
        # Left column: prompt entry plus the action buttons.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(label="Enter your image prompt", lines=3)
            with gr.Row():
                clear_button = gr.Button("Clear")
                submit_button = gr.Button("Generate Image", variant="primary")

        # Right column: results. type="pil" matches the PIL image (or None)
        # that generate_image returns as its first value.
        with gr.Column(scale=2):
            output_image = gr.Image(label="Generated Image", type="pil")
            # Read-only field for success and error messages.
            status_text = gr.Textbox(label="Status", value="", interactive=False)

    # generate_image returns (image, status); Gradio maps them positionally
    # onto the two output components below.
    submit_button.click(
        fn=generate_image,
        inputs=prompt_input,
        outputs=[output_image, status_text],
    )

    # Reset prompt, image and status to their empty states.
    clear_button.click(
        fn=lambda: ["", None, ""],
        inputs=None,
        outputs=[prompt_input, output_image, status_text],
    )

    gr.HTML('<br><p style="text-align:center;">Interface by Gradio</p>')
# --- Launch the App ---
if __name__ == "__main__":
    # Pillow is imported unconditionally at the top of this module, so a
    # missing Pillow install already fails at import time; a second check
    # here could never trigger (and its bare exit() would have reported
    # success via exit status 0).
    #
    # 'requests' is optional: generate_image only imports it when the image
    # API returns a URL, so a missing install is a warning, not a fatal error.
    try:
        import requests  # noqa: F401  (availability check only)
    except ImportError:
        print("Warning: 'requests' library not found. Won't be able to display images from URLs if the API returns a URL.")
    print("Launching Gradio Interface...")
    # mcp_server=True asks Gradio to also serve the app as an MCP server
    # (requires a Gradio release with MCP support).
    demo.launch(mcp_server=True)