Spaces:
Running
Running
import io | |
import os | |
import base64 | |
from typing import List, Optional, Union, Dict, Any | |
import gradio as gr | |
import numpy as np | |
from PIL import Image | |
import requests | |
import json | |
class ImageAgent:
    """Small HTTP client for the image generation and image editing endpoints.

    Requests are sent to the ``gpt-{origin}.singularity-ai.com`` proxy and
    authenticated with an ``app_key`` header.
    """

    def __init__(self, model: str = 'gpt-image-1', api_key: Optional[str] = None):
        """Store the model name, resolve the API key and build the endpoint URLs.

        Args:
            model: Model name placed in the request payload.
            api_key: Value sent as the ``app_key`` header. When falsy, the
                ``API_KEY`` environment variable is used instead.
        """
        self.model = model
        self.origin = 'bj'
        self.api_key = api_key or os.getenv('API_KEY')
        self.gen_url = f'https://gpt-{self.origin}.singularity-ai.com/gpt-proxy/azure/imagen'
        self.edit_url = f'https://gpt-{self.origin}.singularity-ai.com/gpt-proxy/azure/imagen/edit'

    @staticmethod
    def _call_with_retries(send, retries: int, action: str):
        """Call ``send()`` until it succeeds or ``retries`` attempts are used up.

        Args:
            send: Zero-argument callable performing one request attempt. It
                returns the parsed response on success and raises on failure.
            retries: Maximum number of attempts.
            action: Human-readable action name used in log and error messages.

        Returns:
            Whatever the first successful ``send()`` call returns.

        Raises:
            Exception: When every attempt failed (or ``retries`` is not
                positive). The last failure is chained as ``__cause__`` and
                included in the message.
        """
        last_error = None
        for attempt in range(1, retries + 1):
            print(f"第 {attempt} 次发送{action}请求...")
            try:
                return send()
            except Exception as exc:
                # Remember the failure and fall through to the next attempt
                # instead of aborting on the first error.
                last_error = exc
                print(f"发生错误:{exc}")
        message = f"重试超过{retries}次,{action}失败!"
        if last_error is not None:
            # Surface the root cause, since callers only display str(exception).
            message += f"最后一次错误:{last_error}"
        raise Exception(message) from last_error

    def image_generate(self, data: Optional[dict] = None, prompt: Optional[str] = None,
                       retries: int = 3, **kwargs) -> Dict[str, Any]:
        """Request image generation and return the decoded JSON response.

        Args:
            data: Complete request payload. When falsy, a payload is built from
                ``prompt`` and the ``size`` / ``quality`` / ``n`` keyword
                arguments (defaults: ``"auto"``, ``"high"``, ``1``).
            prompt: Text prompt; a built-in sample prompt is used when falsy.
            retries: Maximum number of request attempts.
            **kwargs: Optional ``size``, ``quality`` and ``n`` overrides; other
                keys are ignored.

        Returns:
            The parsed JSON body, guaranteed to have a non-empty ``data`` field.

        Raises:
            Exception: When all attempts fail (non-200 status, invalid JSON,
                empty ``data`` field, or a transport error).
        """
        headers = {
            "app_key": self.api_key,
            "Content-Type": "application/json"
        }
        request_data = data or {
            "model": self.model or "gpt-image-1",
            "prompt": prompt or "a red fox in a snowy forest",
            "size": kwargs.get("size", "auto"),
            "quality": kwargs.get("quality", "high"),
            "n": kwargs.get("n", 1)
        }
        print("***** Request Data - Image Generate *****")
        print(json.dumps(request_data, indent=2))

        def send():
            # A timeout keeps a stalled connection from hanging the caller
            # forever; the value matches the one used by image_edit.
            response = requests.post(self.gen_url, json=request_data, headers=headers, timeout=180)
            print(f"响应状态码: {response.status_code}")
            if response.status_code != 200:
                raise Exception(f"请求失败:{response.text}")
            try:
                response_json = json.loads(response.text)
            except json.JSONDecodeError as exc:
                raise Exception(f"响应内容无法解析为 JSON:{response.text}") from exc
            if not response_json.get("data"):
                raise Exception(f"响应内容 data 字段为空,准备重试...")
            return response_json

        return self._call_with_retries(send, retries, "图像生成")

    def image_edit(self, image: tuple, data: Optional[dict] = None, prompt: Optional[str] = None,
                   retries: int = 3, **kwargs) -> Dict[str, Any]:
        """Request an image edit and return the decoded JSON response.

        Args:
            image: File tuple passed to ``requests`` as the ``image`` upload
                part. The visible caller passes ``(filename, bytes, mime)``.
            data: Complete form payload. When falsy, a payload is built from
                ``prompt`` and the ``size`` / ``quality`` / ``n`` keyword
                arguments (defaults: ``"auto"``, ``"high"``, ``1``).
            prompt: Edit instruction placed in the payload.
            retries: Maximum number of request attempts.
            **kwargs: Optional ``size``, ``quality`` and ``n`` overrides; other
                keys are ignored.

        Returns:
            The parsed JSON body, guaranteed to have a non-empty ``data`` field.

        Raises:
            Exception: When all attempts fail (non-200 status, invalid JSON,
                empty ``data`` field, or a transport error).
        """
        request_data = data or {
            "model": self.model or "gpt-image-1",
            "prompt": prompt,
            "size": kwargs.get("size", "auto"),
            "quality": kwargs.get("quality", "high"),
            "n": kwargs.get("n", 1)
        }
        print("***** Request Data - Image Edit *****")
        print(json.dumps(request_data, indent=2))

        def send():
            # No explicit Content-Type: requests sets the multipart boundary.
            headers = {
                "app_key": self.api_key
            }
            response = requests.post(self.edit_url, headers=headers, files={"image": image},
                                     data=request_data, timeout=180)
            print(f"响应状态码: {response.status_code}")
            if response.status_code != 200:
                raise Exception(f"请求失败:{response.text}")
            try:
                response_json = json.loads(response.text)
            except json.JSONDecodeError as exc:
                raise Exception(f"响应内容无法解析为JSON: {response.text}") from exc
            if not response_json.get("data"):
                raise Exception(f"响应内容 data 字段为空: {response.text}, 准备重试...")
            return response_json

        return self._call_with_retries(send, retries, "图像编辑")
# --- Constants --- | |
# Model name placed in every request payload.
MODEL = "gpt-image-1"

# Options offered by the "Size" dropdown.
SIZE_CHOICES = [
    "auto",
    "1024x1024",
    "1536x1024",
    "1024x1536",
]

# Options offered by the "Quality" dropdown.
QUALITY_CHOICES = [
    "auto",
    "low",
    "medium",
    "high",
]

# Options offered by the "Output Format" radio group.
FORMAT_CHOICES = [
    "png",
]
def _client(key: Optional[str]) -> ImageAgent:
    """Build an ImageAgent from the user-supplied key or the environment.

    Args:
        key: API key typed into the UI. ``None``, empty and whitespace-only
            values are treated as "not provided".

    Returns:
        An ``ImageAgent`` configured with the resolved API key.

    Raises:
        gr.Error: When neither ``key`` nor the ``API_KEY`` environment
            variable provides a key.
    """
    # Guard against None so a missing value falls back to the environment
    # instead of crashing on .strip().
    api_key = (key or "").strip() or os.getenv("API_KEY", "")
    if not api_key:
        raise gr.Error("Please enter your API key")
    return ImageAgent(api_key=api_key)
def _img_list(resp: Dict[str, Any]) -> List[Union[np.ndarray, str]]: | |
""" | |
Decode base64 images into numpy arrays (for Gradio) or pass URL strings directly. | |
""" | |
imgs: List[Union[np.ndarray, str]] = [] | |
for d in resp.get("data", []): | |
if d.get("b64_json", None): | |
data = base64.b64decode(d.get("b64_json")) | |
img = Image.open(io.BytesIO(data)) | |
imgs.append(np.array(img)) | |
elif d.get("url", None): | |
imgs.append(d.get("url")) | |
return imgs | |
def _common_kwargs(
    prompt: Optional[str],
    n: int,
    size: str,
    quality: str,
) -> Dict[str, Any]:
    """Assemble the keyword arguments shared by the generate and edit calls.

    ``model`` and ``n`` are always present. ``size`` and ``quality`` are
    included only when they differ from ``"auto"``, and ``prompt`` only when
    it is not ``None``.

    NOTE(review): when ``quality`` is "auto" the key is omitted, so the
    callee's own default applies (ImageAgent falls back to "high") - confirm
    this is intended.
    """
    params: Dict[str, Any] = {"model": MODEL, "n": n}
    # Keep the original key order: size, then quality, then prompt.
    params.update(
        {name: value for name, value in (("size", size), ("quality", quality)) if value != "auto"}
    )
    if prompt is not None:
        params["prompt"] = prompt
    return params
def convert_to_format(
    img_array: np.ndarray,
    target_fmt: str,
    quality: int = 75,
) -> np.ndarray:
    """Re-encode an image array in ``target_fmt`` and return the decoded result.

    Args:
        img_array: Image pixels; values are cast to ``uint8`` before encoding.
        target_fmt: Target format name such as ``"jpeg"``, ``"jpg"`` or
            ``"webp"`` (case-insensitive).
        quality: Encoder quality setting forwarded to ``Image.save``.

    Returns:
        The image as a numpy array after a save/load round trip through the
        target format.
    """
    fmt = target_fmt.strip().upper()
    if fmt == "JPG":
        # Pillow registers the encoder under "JPEG"; "JPG" is only a file suffix.
        fmt = "JPEG"
    img = Image.fromarray(img_array.astype(np.uint8))
    if fmt == "JPEG" and img.mode not in ("L", "RGB", "CMYK"):
        # JPEG has no alpha channel; saving e.g. RGBA would raise, so drop it.
        img = img.convert("RGB")
    buf = io.BytesIO()
    img.save(buf, format=fmt, quality=quality)
    buf.seek(0)
    return np.array(Image.open(buf))
# ---------- Generate ---------- # | |
def generate(
    api_key: str,
    prompt: str,
    n: int,
    size: str,
    quality: str,
):
    """Generate images for ``prompt`` and return them for display in a gallery.

    Args:
        api_key: API key entered in the UI (may be empty; see ``_client``).
        prompt: Text prompt; must be non-empty.
        n: Number of images to request.
        size: Requested size, or ``"auto"``.
        quality: Requested quality, or ``"auto"``.

    Returns:
        The list produced by ``_img_list`` (numpy arrays and/or URL strings).

    Raises:
        gr.Error: When the prompt is empty, the API key is missing, or the
            request fails.
    """
    if not prompt:
        raise gr.Error("Please enter a prompt.")
    try:
        agent = _client(api_key)
        common_args = _common_kwargs(prompt, n, size, quality)
        resp = agent.image_generate(retries=3, **common_args)
        return _img_list(resp)
    except gr.Error:
        # Already a user-facing error (e.g. missing API key): pass it through
        # unchanged instead of re-wrapping it.
        raise
    except Exception as e:
        # Keep the original exception chained for server-side debugging.
        raise gr.Error(str(e)) from e
# ---------- Edit ---------- # | |
def _bytes_from_numpy(arr: np.ndarray) -> bytes:
    """Encode an image array as PNG and return the encoded bytes."""
    pixels = arr.astype(np.uint8)
    with io.BytesIO() as sink:
        Image.fromarray(pixels).save(sink, format="PNG")
        return sink.getvalue()
def edit_image(
    api_key: str,
    image_numpy: Optional[np.ndarray],
    prompt: str,
    n: int,
    size: str,
    quality: str
):
    """Edit the uploaded image according to ``prompt`` and return the results.

    Args:
        api_key: API key entered in the UI (may be empty; see ``_client``).
        image_numpy: Uploaded image as a numpy array, or ``None`` when no
            image was provided.
        prompt: Edit instruction; must be non-empty.
        n: Number of images to request.
        size: Requested size, or ``"auto"``.
        quality: Requested quality, or ``"auto"``.

    Returns:
        The list produced by ``_img_list`` (numpy arrays and/or URL strings).

    Raises:
        gr.Error: When the image or prompt is missing, the API key is
            missing, the image cannot be encoded, or the request fails.
    """
    if image_numpy is None:
        raise gr.Error("Please upload an image.")
    if not prompt:
        raise gr.Error("Please enter an edit prompt.")
    try:
        # Encode inside the try block so an encoding failure is also reported
        # as a user-facing error.
        img_bytes = _bytes_from_numpy(image_numpy)
        agent = _client(api_key)
        common_args = _common_kwargs(prompt, n, size, quality)
        image_tuple = ("image.png", img_bytes, "image/png")
        resp = agent.image_edit(image=image_tuple, retries=3, **common_args)
        return _img_list(resp)
    except gr.Error:
        # Already a user-facing error (e.g. missing API key): pass it through
        # unchanged instead of re-wrapping it.
        raise
    except Exception as e:
        # Keep the original exception chained for server-side debugging.
        raise gr.Error(str(e)) from e
# ---------- UI ---------- # | |
def build_ui():
    """Assemble the Gradio Blocks app and return it (without launching).

    Layout: title, collapsible API-key box, a row of shared controls
    (count, size, quality), an output-format row, then "Generate" and
    "Edit" tabs whose buttons call ``generate`` and ``edit_image``.
    """
    with gr.Blocks(title="GPT-Image-1 (BYOT)") as demo:
        gr.Markdown("""# 🐍 GPT-Image-1 Playground""")

        with gr.Accordion("🔐 API key", open=False):
            api_box = gr.Textbox(
                label="OpenAI API key",
                type="password",
                placeholder="gpt-...",
            )

        with gr.Row():
            count_slider = gr.Slider(1, 4, value=1, step=1, label="Number of images (n)")
            size_dropdown = gr.Dropdown(SIZE_CHOICES, value="auto", label="Size")
            quality_dropdown = gr.Dropdown(QUALITY_CHOICES, value="auto", label="Quality")

        with gr.Row():
            # Rendered in the UI but not wired to any callback.
            out_fmt = gr.Radio(FORMAT_CHOICES, value="png", label="Output Format")

        # Controls shared by both tabs, in the order the callbacks expect.
        shared_controls = [count_slider, size_dropdown, quality_dropdown]

        with gr.Tabs():
            with gr.TabItem("Generate"):
                generate_prompt = gr.Textbox(
                    label="Prompt",
                    lines=3,
                    placeholder="Write down your prompt here",
                    autofocus=True,
                    container=False
                )
                generate_button = gr.Button("Generate 🚀")
                generate_gallery = gr.Gallery(columns=2, height="auto")
                generate_button.click(
                    generate,
                    inputs=[api_box, generate_prompt, *shared_controls],
                    outputs=generate_gallery
                )

            with gr.TabItem("Edit"):
                source_image = gr.Image(type="numpy", label="Image to edit", height=400)
                edit_prompt = gr.Textbox(
                    label="Edit prompt",
                    lines=2,
                    placeholder="Write down your prompt here",
                )
                edit_button = gr.Button("Edit 🖌️")
                edit_gallery = gr.Gallery(columns=2, height="auto")
                edit_button.click(
                    edit_image,
                    inputs=[api_box, source_image, edit_prompt, *shared_controls],
                    outputs=edit_gallery,
                )

    return demo
if __name__ == "__main__":
    # Build the UI, then read the launch flags from the environment; each
    # flag is enabled only by the exact string "true".
    app = build_ui()
    share_enabled = os.getenv("GRADIO_SHARE") == "true"
    debug_enabled = os.getenv("GRADIO_DEBUG") == "true"
    app.launch(share=share_enabled, debug=debug_enabled)