# skyscript's picture
# Create app.py
# 27568a2 verified
import io
import os
import base64
from typing import List, Optional, Union, Dict, Any
import gradio as gr
import numpy as np
from PIL import Image
import requests
import json
class ImageAgent:
    """Client for the image-generation and image-edit proxy endpoints.

    Generation requests are sent as a JSON body to ``gen_url``; edit requests
    are sent as a multipart form upload to ``edit_url``. The key is passed in
    the ``app_key`` header.
    """

    def __init__(self, model: str = 'gpt-image-1', api_key: Optional[str] = None,
                 timeout: float = 180):
        """Configure the client.

        Args:
            model: Model name sent in request bodies built by this client.
            api_key: Proxy key; falls back to the ``API_KEY`` environment
                variable when not given (or empty).
            timeout: Per-request timeout in seconds passed to ``requests.post``.
        """
        self.model = model
        # Host prefix segment interpolated into both proxy URLs
        # (presumably a region code — confirm with the proxy owner).
        self.origin = 'bj'
        self.api_key = api_key or os.getenv('API_KEY')
        self.timeout = timeout
        self.gen_url = f'https://gpt-{self.origin}.singularity-ai.com/gpt-proxy/azure/imagen'
        self.edit_url = f'https://gpt-{self.origin}.singularity-ai.com/gpt-proxy/azure/imagen/edit'

    @staticmethod
    def _with_retries(attempt, retries: int, failure_msg: str) -> Dict[str, Any]:
        """Call ``attempt(i)`` for ``i = 0 .. retries-1`` until one succeeds.

        Args:
            attempt: Callable taking the zero-based attempt index and returning
                the parsed response, or raising on failure.
            retries: Maximum number of attempts.
            failure_msg: Message for the exception raised once every attempt
                has failed.

        Returns:
            Whatever the first successful ``attempt`` call returns.

        Raises:
            Exception: ``failure_msg`` (with the last error appended, if there
                was one) after all attempts fail, chained to that last error.
        """
        last_error: Optional[Exception] = None
        for i in range(retries):
            try:
                return attempt(i)
            except Exception as e:  # retry boundary: any failure triggers another attempt
                last_error = e
                print(f"发生错误:{e}")
        if last_error is None:
            # retries <= 0: no request was attempted.
            raise Exception(failure_msg)
        raise Exception(f"{failure_msg}发生错误:{last_error}") from last_error

    def image_generate(self, data: Optional[dict] = None, prompt: Optional[str] = None,
                       retries: int = 3, **kwargs) -> Dict[str, Any]:
        """Request new images from ``gen_url``, retrying on failure.

        Args:
            data: Complete request body; when truthy it is sent as-is and
                ``prompt``/``kwargs`` are ignored.
            prompt: Text prompt; a sample prompt is used when empty.
            retries: Maximum number of attempts.
            **kwargs: Optional ``size`` (default "auto"), ``quality``
                (default "high") and ``n`` (default 1); other keys are ignored.

        Returns:
            The decoded JSON response, which has a non-empty ``data`` field.

        Raises:
            Exception: When every attempt fails.
        """
        headers = {
            "app_key": self.api_key,
            "Content-Type": "application/json",
        }
        request_data = data or {
            "model": self.model or "gpt-image-1",
            "prompt": prompt or "a red fox in a snowy forest",
            "size": kwargs.get("size", "auto"),
            "quality": kwargs.get("quality", "high"),
            "n": kwargs.get("n", 1),
        }
        print("***** Request Data - Image Generate *****")
        print(json.dumps(request_data, indent=2))

        def attempt(i: int) -> Dict[str, Any]:
            print(f"第 {i+1} 次发送图像生成请求...")
            # Explicit timeout so a stalled proxy cannot hang the UI forever.
            response = requests.post(self.gen_url, json=request_data,
                                     headers=headers, timeout=self.timeout)
            print(f"响应状态码: {response.status_code}")
            if response.status_code != 200:
                raise Exception(f"请求失败:{response.text}")
            try:
                response_json = json.loads(response.text)
            except json.JSONDecodeError as err:
                raise Exception(f"响应内容无法解析为 JSON:{response.text}") from err
            if not isinstance(response_json, dict) or not response_json.get("data"):
                raise Exception("响应内容 data 字段为空,准备重试...")
            return response_json

        return self._with_retries(attempt, retries, f"重试超过{retries}次,图像生成失败!")

    def image_edit(self, image: tuple, data: Optional[dict] = None, prompt: Optional[str] = None,
                   retries: int = 3, **kwargs) -> Dict[str, Any]:
        """Upload an image to ``edit_url`` with an edit prompt, retrying on failure.

        Args:
            image: Multipart file spec sent as the ``image`` form field, e.g.
                ``(filename, bytes, mime_type)``.
            data: Complete form fields; when truthy they are sent as-is and
                ``prompt``/``kwargs`` are ignored.
            prompt: Edit instruction.
            retries: Maximum number of attempts.
            **kwargs: Optional ``size`` (default "auto"), ``quality``
                (default "high") and ``n`` (default 1); other keys are ignored.

        Returns:
            The decoded JSON response, which has a non-empty ``data`` field.

        Raises:
            Exception: When every attempt fails.
        """
        request_data = data or {
            "model": self.model or "gpt-image-1",
            "prompt": prompt,
            "size": kwargs.get("size", "auto"),
            "quality": kwargs.get("quality", "high"),
            "n": kwargs.get("n", 1),
        }
        print("***** Request Data - Image Edit *****")
        print(json.dumps(request_data, indent=2))
        # No Content-Type header: requests sets the multipart boundary itself.
        headers = {
            "app_key": self.api_key,
        }

        def attempt(i: int) -> Dict[str, Any]:
            print(f"第 {i+1} 次发送图像编辑请求...")
            response = requests.post(self.edit_url, headers=headers, files={"image": image},
                                     data=request_data, timeout=self.timeout)
            print(f"响应状态码: {response.status_code}")
            if response.status_code != 200:
                raise Exception(f"请求失败:{response.text}")
            try:
                response_json = json.loads(response.text)
            except json.JSONDecodeError as err:
                raise Exception(f"响应内容无法解析为JSON: {response.text}") from err
            if not isinstance(response_json, dict) or not response_json.get("data"):
                raise Exception(f"响应内容 data 字段为空: {response.text}, 准备重试...")
            return response_json

        return self._with_retries(attempt, retries, f"重试超过{retries}次,图像编辑失败!")
# --- Constants ---
# Model identifier placed in every request built by the UI handlers.
MODEL = "gpt-image-1"
# Options for the "Size" dropdown; "auto" is listed first and is the UI default.
SIZE_CHOICES = [
    "auto",
    "1024x1024",
    "1536x1024",
    "1024x1536",
]
# Options for the "Quality" dropdown; "auto" is the UI default.
QUALITY_CHOICES = [
    "auto",
    "low",
    "medium",
    "high",
]
# Options for the "Output Format" radio button.
FORMAT_CHOICES = [
    "png",
]
def _client(key: Optional[str]) -> ImageAgent:
    """Build an ``ImageAgent`` from the key typed into the UI.

    Falls back to the ``API_KEY`` environment variable when the textbox value
    is empty, whitespace-only, or ``None``. Surrounding whitespace is stripped
    from whichever key is used.

    Raises:
        gr.Error: If no key is available from either source.
    """
    # `key or ""` guards against None, which has no .strip().
    api_key = (key or "").strip() or os.getenv("API_KEY", "").strip()
    if not api_key:
        raise gr.Error("Please enter your API key")
    return ImageAgent(api_key=api_key)
def _img_list(resp: Dict[str, Any]) -> List[Union[np.ndarray, str]]:
    """Turn an Images API response into gallery-ready items.

    Each entry in ``resp["data"]`` with a truthy ``b64_json`` value is
    base64-decoded, opened with Pillow and converted to a numpy array.
    Otherwise, a truthy ``url`` value is passed through unchanged. Entries
    with neither value are skipped.
    """
    gallery: List[Union[np.ndarray, str]] = []
    for entry in resp.get("data", []):
        encoded = entry.get("b64_json")
        if encoded:
            raw = base64.b64decode(encoded)
            gallery.append(np.array(Image.open(io.BytesIO(raw))))
            continue
        link = entry.get("url")
        if link:
            gallery.append(link)
    return gallery
def _common_kwargs(
    prompt: Optional[str],
    n: int,
    size: str,
    quality: str,
) -> Dict[str, Any]:
    """Collect the shared Images API keyword arguments from the UI controls.

    ``model`` and ``n`` are always present. ``size`` and ``quality`` are added
    only when they differ from "auto", and ``prompt`` only when it is not None.
    """
    params: Dict[str, Any] = {"model": MODEL, "n": n}
    for name, choice in (("size", size), ("quality", quality)):
        if choice != "auto":
            params[name] = choice
    if prompt is not None:
        params["prompt"] = prompt
    return params
def convert_to_format(
    img_array: np.ndarray,
    target_fmt: str,
    quality: int = 75,
) -> np.ndarray:
    """Re-encode an image array in ``target_fmt`` and decode it back to an array.

    Useful for previewing the effect of lossy formats (e.g. JPEG/WebP) on an
    image.

    Args:
        img_array: Image pixels; cast to ``uint8`` before encoding.
        target_fmt: Pillow format name, case-insensitive. ``"jpg"`` is
            accepted as an alias for ``"jpeg"``.
        quality: Encoder quality forwarded to Pillow. Formats that do not use
            it, such as PNG, ignore it.

    Returns:
        The pixels of the re-encoded image as a numpy array.
    """
    fmt = target_fmt.upper()
    if fmt == "JPG":
        # Pillow registers the format as "JPEG"; "JPG" is only a file extension.
        fmt = "JPEG"
    img = Image.fromarray(img_array.astype(np.uint8))
    # JPEG cannot store alpha; Pillow raises OSError when writing RGBA/LA
    # images as JPEG, so flatten them to RGB first.
    if fmt == "JPEG" and img.mode not in ("RGB", "L"):
        img = img.convert("RGB")
    buf = io.BytesIO()
    img.save(buf, format=fmt, quality=quality)
    buf.seek(0)
    with Image.open(buf) as decoded:
        return np.array(decoded)
# ---------- Generate ---------- #
def generate(
    api_key: str,
    prompt: str,
    n: int,
    size: str,
    quality: str,
):
    """Gradio handler for the Generate tab.

    Args:
        api_key: Value of the API key textbox (may be empty).
        prompt: Text prompt; must contain non-whitespace characters.
        n: Number of images requested.
        size: Size choice from the dropdown ("auto" means omit).
        quality: Quality choice from the dropdown ("auto" means omit).

    Returns:
        A list of images (numpy arrays or URL strings) for the gallery.

    Raises:
        gr.Error: On missing input or any request/decoding failure.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")
    try:
        agent = _client(api_key)
        # Slider values may arrive as floats; the API's n is an integer count.
        common_args = _common_kwargs(prompt, int(n), size, quality)
        resp = agent.image_generate(retries=3, **common_args)
        return _img_list(resp)
    except gr.Error:
        # Already user-facing. Re-wrapping would pass it through str(), which
        # may not reproduce the original message text (e.g. adds quotes).
        raise
    except Exception as e:
        raise gr.Error(str(e)) from e
# ---------- Edit ---------- #
def _bytes_from_numpy(arr: np.ndarray) -> bytes:
    """Encode an image array as a PNG file and return its bytes.

    The array is cast to ``uint8`` before being handed to Pillow.
    """
    with io.BytesIO() as png_buffer:
        Image.fromarray(arr.astype(np.uint8)).save(png_buffer, format="PNG")
        return png_buffer.getvalue()
def edit_image(
    api_key: str,
    image_numpy: Optional[np.ndarray],
    prompt: str,
    n: int,
    size: str,
    quality: str
):
    """Gradio handler for the Edit tab.

    Args:
        api_key: Value of the API key textbox (may be empty).
        image_numpy: Uploaded image as a numpy array, or None if nothing was
            uploaded.
        prompt: Edit instruction; must contain non-whitespace characters.
        n: Number of images requested.
        size: Size choice from the dropdown ("auto" means omit).
        quality: Quality choice from the dropdown ("auto" means omit).

    Returns:
        A list of images (numpy arrays or URL strings) for the gallery.

    Raises:
        gr.Error: On missing input or any encoding/request/decoding failure.
    """
    if image_numpy is None:
        raise gr.Error("Please upload an image.")
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter an edit prompt.")
    try:
        # Inside the try so encoding failures also surface as a gr.Error.
        img_bytes = _bytes_from_numpy(image_numpy)
        agent = _client(api_key)
        # Slider values may arrive as floats; the API's n is an integer count.
        common_args = _common_kwargs(prompt, int(n), size, quality)
        image_tuple = ("image.png", img_bytes, "image/png")
        resp = agent.image_edit(image=image_tuple, retries=3, **common_args)
        return _img_list(resp)
    except gr.Error:
        # Already user-facing. Re-wrapping would pass it through str(), which
        # may not reproduce the original message text (e.g. adds quotes).
        raise
    except Exception as e:
        raise gr.Error(str(e)) from e
# ---------- UI ---------- #
def build_ui():
    """Assemble the Gradio Blocks app with Generate and Edit tabs.

    Both tabs share the API key textbox and the n/size/quality controls.
    These are passed to the handlers after each tab's own inputs.

    Returns:
        The ``gr.Blocks`` instance; it is not launched here.
    """
    with gr.Blocks(title="GPT-Image-1 (BYOT)") as demo:
        gr.Markdown("# 🐍 GPT-Image-1 Playground")

        with gr.Accordion("🔐 API key", open=False):
            key_box = gr.Textbox(
                label="OpenAI API key",
                type="password",
                placeholder="gpt-...",
            )

        with gr.Row():
            count_slider = gr.Slider(1, 4, value=1, step=1, label="Number of images (n)")
            size_dropdown = gr.Dropdown(SIZE_CHOICES, value="auto", label="Size")
            quality_dropdown = gr.Dropdown(QUALITY_CHOICES, value="auto", label="Quality")

        with gr.Row():
            # NOTE(review): this radio is not connected to any handler below.
            gr.Radio(FORMAT_CHOICES, value="png", label="Output Format")

        shared_inputs = [count_slider, size_dropdown, quality_dropdown]

        with gr.Tabs():
            with gr.TabItem("Generate"):
                gen_prompt = gr.Textbox(
                    label="Prompt",
                    lines=3,
                    placeholder="Write down your prompt here",
                    autofocus=True,
                    container=False,
                )
                gen_button = gr.Button("Generate 🚀")
                gen_gallery = gr.Gallery(columns=2, height="auto")
                gen_button.click(
                    generate,
                    inputs=[key_box, gen_prompt, *shared_inputs],
                    outputs=gen_gallery,
                )

            with gr.TabItem("Edit"):
                edit_source = gr.Image(type="numpy", label="Image to edit", height=400)
                edit_prompt = gr.Textbox(
                    label="Edit prompt",
                    lines=2,
                    placeholder="Write down your prompt here",
                )
                edit_button = gr.Button("Edit 🖌️")
                edit_gallery = gr.Gallery(columns=2, height="auto")
                edit_button.click(
                    edit_image,
                    inputs=[key_box, edit_source, edit_prompt, *shared_inputs],
                    outputs=edit_gallery,
                )

    return demo
if __name__ == "__main__":
    # Both flags are opt-in: only the exact string "true" enables them.
    share_enabled = os.getenv("GRADIO_SHARE") == "true"
    debug_enabled = os.getenv("GRADIO_DEBUG") == "true"
    app = build_ui()
    app.launch(share=share_enabled, debug=debug_enabled)