File size: 6,700 Bytes
8bdda4d
 
6fdc459
71c9968
 
 
 
c07f100
 
 
 
aca81df
c07f100
aca81df
 
c07f100
71c9968
aca81df
c07f100
db19cd1
 
 
71c9968
c07f100
 
 
 
8492d88
aca81df
c07f100
aca81df
c07f100
aca81df
 
c07f100
8492d88
 
c07f100
 
aca81df
71c9968
8492d88
aca81df
 
 
71c9968
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aca81df
 
71c9968
 
 
 
 
 
 
 
 
aca81df
 
 
 
 
 
 
 
 
 
 
 
 
c07f100
aca81df
 
 
8492d88
c07f100
 
aca81df
8492d88
c07f100
aca81df
c07f100
aca81df
f9484ea
6fdc459
aca81df
c07f100
 
 
 
aca81df
 
db19cd1
71c9968
 
db19cd1
aca81df
c07f100
 
 
 
 
aca81df
c07f100
db19cd1
 
aca81df
 
c07f100
 
 
f9484ea
71c9968
 
 
 
 
 
 
 
 
 
 
 
 
 
c07f100
e472f1a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import os
import openai
import gradio as gr
import base64
from openai import OpenAI
from PIL import Image  # Import Pillow
import io              # Import io for byte streams

# --- Configuration ---
# Read the key from the environment. When it is missing we only warn here;
# the client construction below then falls back to OpenAI() without an
# explicit key.
api_key = os.getenv("OPENAI_API_KEY")
if api_key is None:
    print("Warning: OPENAI_API_KEY environment variable not set. Using default behavior.")

# Model passed to client.images.generate; "dall-e-3" is the alternative to
# try if this name is rejected.
IMAGE_MODEL = "gpt-image-1"

# Shared OpenAI client. It stays None if construction fails, so callers can
# check it and report the problem instead of the module failing at import.
client = None
try:
    client = OpenAI(api_key=api_key) if api_key else OpenAI()
except Exception as init_err:
    print(f"Error initializing OpenAI client: {init_err}")

# --- Image Generation Function ---
# Seconds to wait when downloading an image that the API returned by URL.
# Without a timeout, requests.get can block the Gradio worker indefinitely.
_DOWNLOAD_TIMEOUT_SECONDS = 60


def _fail(message):
    """Print *message* to the console and return it as a failed (image, status) pair."""
    print(message)
    return None, message


def _pil_from_bytes(image_bytes):
    """Decode raw image bytes into a fully loaded PIL image.

    ``Image.open`` only reads the header; ``load()`` forces a full decode here
    so corrupt data raises inside ``generate_image``'s error handling rather
    than later, when Gradio serializes the output.
    """
    pil_image = Image.open(io.BytesIO(image_bytes))
    pil_image.load()
    return pil_image


def _pil_from_url(image_url):
    """Download *image_url* and return ``(PIL image or None, status message)``."""
    try:
        # Optional dependency: only needed when the API answers with a URL.
        import requests

        img_response = requests.get(image_url, timeout=_DOWNLOAD_TIMEOUT_SECONDS)
        img_response.raise_for_status()  # turn 4xx/5xx into an exception
        pil_image = _pil_from_bytes(img_response.content)
        return pil_image, "Image generated successfully (from URL)."
    except Exception as download_err:
        print(f"Failed to download image from URL {image_url}: {download_err}")
        return None, f"Failed to download image from URL: {download_err}"


def generate_image(prompt_text):
    """Generate one 1024x1024 image for the prompt using ``IMAGE_MODEL``.

    Intended as a Gradio callback: it never raises and always returns a pair.

    Args:
        prompt_text: The image prompt. Empty or whitespace-only prompts are
            rejected without calling the API.

    Returns:
        A tuple ``(image, status)``. ``image`` is a ``PIL.Image.Image`` on
        success and ``None`` on failure; ``status`` is a human-readable
        message for the status box.
    """
    if not client:
        return None, "Error: OpenAI client not initialized. Check API Key or initialization."
    if not prompt_text or not prompt_text.strip():
        return None, "Error: Please enter a prompt."

    print(f"Generating image for prompt: '{prompt_text}' using model {IMAGE_MODEL}")
    try:
        response = client.images.generate(
            model=IMAGE_MODEL,
            prompt=prompt_text,
            n=1,
            size="1024x1024",
        )

        # An empty or missing data list is reported as an unexpected format
        # instead of surfacing as an IndexError.
        first = response.data[0] if getattr(response, "data", None) else None
        image_url = getattr(first, "url", None)
        b64_data = getattr(first, "b64_json", None)

        # The API returns either a hosted URL or the image inline as base64.
        if image_url:
            print(f"Image generated successfully (URL): {image_url}")
            return _pil_from_url(image_url)
        if b64_data:
            print("Image generated successfully (base64 received).")
            pil_image = _pil_from_bytes(base64.b64decode(b64_data))
            return pil_image, "Image generated successfully (from base64)."

        print("Error: Unexpected response format from API.")
        print(first)
        return None, "Error: Received unexpected response format from the image API."

    except openai.AuthenticationError as e:
        return _fail(f"Authentication Error: Check your API key. Details: {e}")
    except openai.RateLimitError as e:
        return _fail(f"Rate Limit Error: You might have exceeded your quota. Details: {e}")
    except openai.BadRequestError as e:
        return _fail(f"Bad Request Error: {e}")
    except Exception as e:
        # UI boundary: report anything else in the status box instead of
        # letting the callback crash.
        return _fail(f"An unexpected error occurred: {e}")

# --- Gradio Interface ---
with gr.Blocks(title=f"AI Image Generator ({IMAGE_MODEL})") as demo:
    # Page header.
    gr.Markdown(f"# AI Image Generator ({IMAGE_MODEL})")
    gr.Markdown(f"Generate AI-generated images using the {IMAGE_MODEL} model.")

    with gr.Row():
        # Left column: prompt entry plus the action buttons.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(lines=3, label="Enter your image prompt")
            with gr.Row():
                clear_button = gr.Button("Clear")
                submit_button = gr.Button("Generate Image", variant="primary")

        # Right column: results. type="pil" matches the PIL images that
        # generate_image returns.
        with gr.Column(scale=2):
            output_image = gr.Image(type="pil", label="Generated Image")
            # Read-only box for success and error messages.
            status_text = gr.Textbox(interactive=False, value="", label="Status")

    # generate_image returns (image, status): one value per output widget.
    submit_button.click(
        fn=generate_image,
        inputs=prompt_input,
        outputs=[output_image, status_text],
    )
    # Reset the prompt, the image (None clears it) and the status in one go.
    clear_button.click(
        fn=lambda: ["", None, ""],
        inputs=None,
        outputs=[prompt_input, output_image, status_text],
    )

    gr.HTML('<br><p style="text-align:center;">Interface by Gradio</p>')


# --- Launch the App ---
if __name__ == "__main__":
    # No Pillow check is needed here: the module-level `from PIL import Image`
    # has already succeeded by the time this code runs.

    # `requests` is optional; it is only used when the API answers with an
    # image URL instead of base64 data.
    try:
        import requests  # noqa: F401  (availability probe only)
    except ImportError:
        print("Warning: 'requests' library not found. Won't be able to display images from URLs if the API returns a URL.")

    print("Launching Gradio Interface...")
    demo.launch(mcp_server=True)