Spaces:
Runtime error
Runtime error
File size: 6,700 Bytes
8bdda4d 6fdc459 71c9968 c07f100 aca81df c07f100 aca81df c07f100 71c9968 aca81df c07f100 db19cd1 71c9968 c07f100 8492d88 aca81df c07f100 aca81df c07f100 aca81df c07f100 8492d88 c07f100 aca81df 71c9968 8492d88 aca81df 71c9968 aca81df 71c9968 aca81df c07f100 aca81df 8492d88 c07f100 aca81df 8492d88 c07f100 aca81df c07f100 aca81df f9484ea 6fdc459 aca81df c07f100 aca81df db19cd1 71c9968 db19cd1 aca81df c07f100 aca81df c07f100 db19cd1 aca81df c07f100 f9484ea 71c9968 c07f100 e472f1a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import os
import openai
import gradio as gr
import base64
from openai import OpenAI
from PIL import Image # Import Pillow
import io # Import io for byte streams
# --- Configuration ---
# The key is read explicitly so that its absence is reported once at startup.
api_key = os.getenv("OPENAI_API_KEY")
if api_key is None:
    print("Warning: OPENAI_API_KEY environment variable not set. Using default behavior.")

# Model name passed to client.images.generate.
IMAGE_MODEL = "gpt-image-1"
# IMAGE_MODEL = "dall-e-3" # Use this if 'gpt-image-1' is incorrect

# Module-wide OpenAI client. It remains None when construction fails, and the
# image-generation function refuses to run in that case.
client = None
try:
    # With no key, let the SDK fall back to its own default configuration.
    client = OpenAI(api_key=api_key) if api_key else OpenAI()
except Exception as init_err:
    print(f"Error initializing OpenAI client: {init_err}")
# --- Image Generation Function ---
# Seconds to wait for an image download when the API answers with a URL.
# requests.get has no default timeout, so without one a stalled download
# would block the request handler indefinitely.
_DOWNLOAD_TIMEOUT_SECONDS = 60


def _fail(message):
    """Print *message* and return it as a failed ``(image, status)`` pair."""
    print(message)
    return None, message


def _download_image(url):
    """Download *url* and return its body decoded as a fully loaded PIL image."""
    # requests is optional: only URL-style responses need it, so it is imported
    # lazily; an ImportError is reported to the caller like any download failure.
    import requests

    img_response = requests.get(url, timeout=_DOWNLOAD_TIMEOUT_SECONDS)
    img_response.raise_for_status()  # 4xx/5xx -> requests.HTTPError
    pil_image = Image.open(io.BytesIO(img_response.content))
    # Image.open is lazy; decode now so corrupt data is reported in the status
    # message instead of failing later while the UI renders the image.
    pil_image.load()
    return pil_image


def _decode_b64_image(b64_data):
    """Decode a base64 image payload into a fully loaded PIL image."""
    pil_image = Image.open(io.BytesIO(base64.b64decode(b64_data)))
    pil_image.load()  # decode eagerly so errors surface inside generate_image
    return pil_image


def generate_image(prompt_text: str):
    """Generate one 1024x1024 image from a text prompt.

    Args:
        prompt_text: Description of the image to generate.

    Returns:
        A ``(image, status)`` tuple: a PIL image (None on failure) and a
        human-readable status or error message.
    """
    if not client:
        return None, "Error: OpenAI client not initialized. Check API Key or initialization."
    # Reject empty and whitespace-only prompts before spending an API call.
    if not prompt_text or not prompt_text.strip():
        return None, "Error: Please enter a prompt."
    print(f"Generating image for prompt: '{prompt_text}' using model {IMAGE_MODEL}")
    try:
        response = client.images.generate(
            model=IMAGE_MODEL,
            prompt=prompt_text,
            n=1,
            size="1024x1024",
        )
        # Guard against an empty/missing data list instead of an opaque IndexError.
        if not response.data:
            print("Error: Unexpected response format from API.")
            print(response)
            return None, "Error: Received unexpected response format from the image API."

        item = response.data[0]
        image_url = getattr(item, "url", None)
        b64_data = getattr(item, "b64_json", None)

        # The image component uses type="pil", so a URL must be downloaded and
        # decoded; returning the URL string itself would not display.
        if image_url:
            print(f"Image generated successfully (URL): {image_url}")
            try:
                pil_image = _download_image(image_url)
            except Exception as download_err:
                print(f"Failed to download image from URL {image_url}: {download_err}")
                return None, f"Failed to download image from URL: {download_err}"
            return pil_image, "Image generated successfully (from URL)."

        if b64_data:
            print("Image generated successfully (base64 received).")
            return _decode_b64_image(b64_data), "Image generated successfully (from base64)."

        print("Error: Unexpected response format from API.")
        print(item)
        return None, "Error: Received unexpected response format from the image API."
    except openai.AuthenticationError as e:
        return _fail(f"Authentication Error: Check your API key. Details: {str(e)}")
    except openai.RateLimitError as e:
        return _fail(f"Rate Limit Error: You might have exceeded your quota. Details: {str(e)}")
    except openai.BadRequestError as e:
        return _fail(f"Bad Request Error: {str(e)}")
    except Exception as e:
        return _fail(f"An unexpected error occurred: {str(e)}")
# --- Gradio Interface ---
# Components render in the order they are created inside the Blocks context,
# so statement order below defines the on-screen layout.
_app_title = f"AI Image Generator ({IMAGE_MODEL})"

with gr.Blocks(title=_app_title) as demo:
    gr.Markdown(f"# {_app_title}")
    gr.Markdown(f"Generate AI-generated images using the {IMAGE_MODEL} model.")

    with gr.Row():
        # Left (narrow): prompt entry and the two action buttons.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(label="Enter your image prompt", lines=3)
            with gr.Row():
                clear_button = gr.Button("Clear")
                submit_button = gr.Button("Generate Image", variant="primary")

        # Right (wide): the result image plus a read-only status/error line.
        # type="pil" matches the PIL images returned by generate_image.
        with gr.Column(scale=2):
            output_image = gr.Image(label="Generated Image", type="pil")
            status_text = gr.Textbox(label="Status", value="", interactive=False)

    # generate_image returns (image, status), mapped onto these two components in order.
    _result_outputs = [output_image, status_text]

    submit_button.click(
        fn=generate_image,
        inputs=prompt_input,
        outputs=_result_outputs,
    )

    # Reset prompt to "", image to None and status to "".
    clear_button.click(
        fn=lambda: ["", None, ""],
        inputs=None,
        outputs=[prompt_input, *_result_outputs],
    )

    gr.HTML('<br><p style="text-align:center;">Interface by Gradio</p>')
# --- Launch the App ---
if __name__ == "__main__":
    # Pillow and io need no runtime check here: they are imported
    # unconditionally at the top of the module, so a missing Pillow install
    # already fails at import time, before this block can run.
    # requests is optional; it is only needed when the API returns an image URL.
    try:
        import requests  # noqa: F401  (availability check only)
    except ImportError:
        print("Warning: 'requests' library not found. Won't be able to display images from URLs if the API returns a URL.")
    print("Launching Gradio Interface...")
    # NOTE(review): mcp_server=True requires a Gradio release with MCP support
    # (e.g. installed via gradio[mcp]) — confirm the pinned version supports it.
    demo.launch(mcp_server=True)