skyscript commited on
Commit
27568a2
·
verified ·
1 Parent(s): d98a6c5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +273 -0
app.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import base64
4
+ from typing import List, Optional, Union, Dict, Any
5
+ import gradio as gr
6
+ import numpy as np
7
+ from PIL import Image
8
+ import requests
9
+ import json
10
+
11
+
12
+ class ImageAgent:
13
+ def __init__(self, model: str='gpt-image-1', api_key: str=None):
14
+ self.model = model
15
+ self.origin = 'bj'
16
+ self.api_key = api_key or os.getenv('API_KEY')
17
+ self.gen_url = f'https://gpt-{self.origin}.singularity-ai.com/gpt-proxy/azure/imagen'
18
+ self.edit_url = f'https://gpt-{self.origin}.singularity-ai.com/gpt-proxy/azure/imagen/edit'
19
+
20
+ def image_generate(self, data: dict = None, prompt: str = None, retries: int = 3, **kwargs):
21
+ headers = {
22
+ "app_key": self.api_key,
23
+ "Content-Type": "application/json"
24
+ }
25
+
26
+ request_data = data or {
27
+ "model": self.model or "gpt-image-1",
28
+ "prompt": prompt or "a red fox in a snowy forest",
29
+ "size": kwargs.get("size", "auto"),
30
+ "quality": kwargs.get("quality", "high"),
31
+ "n": kwargs.get("n", 1)
32
+ }
33
+
34
+ print("***** Request Data - Image Generate *****")
35
+ print(json.dumps(request_data, indent=2))
36
+
37
+ for i in range(retries):
38
+ try:
39
+ print(f"第 {i+1} 次发送图像生成请求...")
40
+ response = requests.post(self.gen_url, json=request_data, headers=headers, stream=True)
41
+ print(f"响应状态码: {response.status_code}")
42
+
43
+ if response.status_code != 200:
44
+ raise Exception(f"请求失败:{response.text}")
45
+
46
+ try:
47
+ response_json = json.loads(response.text)
48
+ except json.JSONDecodeError:
49
+ raise Exception(f"响应内容无法解析为 JSON:{response.text}")
50
+
51
+ if not response_json.get("data"):
52
+ raise Exception(f"响应内容 data 字段为空,准备重试...")
53
+
54
+ return response_json
55
+ except Exception as e:
56
+ raise Exception(f"发生错误:{e}")
57
+
58
+ raise Exception(f"重试超过{retries}次,图像生成失败!")
59
+
60
+ def image_edit(self, image: tuple, data: dict = None, prompt: str = None, retries: int = 3, **kwargs):
61
+
62
+ request_data = data or {
63
+ "model": self.model or "gpt-image-1",
64
+ "prompt": prompt,
65
+ "size": kwargs.get("size", "auto"),
66
+ "quality": kwargs.get("quality", "high"),
67
+ "n": kwargs.get("n", 1)
68
+ }
69
+
70
+ print("***** Request Data - Image Edit *****")
71
+ print(json.dumps(request_data, indent=2))
72
+
73
+ for i in range(retries):
74
+ try:
75
+ print(f"第 {i+1} 次发送图像编辑请求...")
76
+ headers = {
77
+ "app_key": self.api_key
78
+ }
79
+ response = requests.post(self.edit_url, headers=headers, files={"image": image}, data=request_data, timeout=180)
80
+ print(f"响应状态码: {response.status_code}")
81
+
82
+ if response.status_code != 200:
83
+ raise Exception(f"请求失败:{response.text}")
84
+
85
+ try:
86
+ response_json = json.loads(response.text)
87
+ except json.JSONDecodeError:
88
+ raise Exception(f"响应内容无法解析为JSON: {response.text}")
89
+
90
+ if not response_json.get("data"):
91
+ raise Exception(f"响应内容 data 字段为空: {response.text}, 准备重试...")
92
+
93
+ return response_json
94
+ except Exception as e:
95
+ raise Exception(f"发生错误:{e}")
96
+
97
+ raise Exception(f"重试超过{retries}次,图像编辑失败!")
98
+
99
+
100
+ # --- Constants ---
101
+ MODEL = "gpt-image-1"
102
+ SIZE_CHOICES = ["auto", "1024x1024", "1536x1024", "1024x1536"]
103
+ QUALITY_CHOICES = ["auto", "low", "medium", "high"]
104
+ FORMAT_CHOICES = ["png"]
105
+
106
+
107
+ def _client(key: str) -> ImageAgent:
108
+ """Initializes the Image Agent with the provided API key."""
109
+ api_key = key.strip() or os.getenv("API_KEY", "")
110
+ if not api_key:
111
+ raise gr.Error("Please enter your API key")
112
+ return ImageAgent(api_key=api_key)
113
+
114
+
115
+ def _img_list(resp: Dict[str, Any]) -> List[Union[np.ndarray, str]]:
116
+ """
117
+ Decode base64 images into numpy arrays (for Gradio) or pass URL strings directly.
118
+ """
119
+ imgs: List[Union[np.ndarray, str]] = []
120
+ for d in resp.get("data", []):
121
+ if d.get("b64_json", None):
122
+ data = base64.b64decode(d.get("b64_json"))
123
+ img = Image.open(io.BytesIO(data))
124
+ imgs.append(np.array(img))
125
+ elif d.get("url", None):
126
+ imgs.append(d.get("url"))
127
+ return imgs
128
+
129
+
130
+ def _common_kwargs(
131
+ prompt: Optional[str],
132
+ n: int,
133
+ size: str,
134
+ quality: str,
135
+ ) -> Dict[str, Any]:
136
+ """Prepare keyword args for Images API."""
137
+ kwargs: Dict[str, Any] = {
138
+ "model": MODEL,
139
+ "n": n,
140
+ }
141
+ if size != "auto":
142
+ kwargs["size"] = size
143
+ if quality != "auto":
144
+ kwargs["quality"] = quality
145
+ if prompt is not None:
146
+ kwargs["prompt"] = prompt
147
+ return kwargs
148
+
149
+
150
+ def convert_to_format(
151
+ img_array: np.ndarray,
152
+ target_fmt: str,
153
+ quality: int = 75,
154
+ ) -> np.ndarray:
155
+ """
156
+ Convert a PIL numpy array to target_fmt (JPEG/WebP) and return as numpy array.
157
+ """
158
+ img = Image.fromarray(img_array.astype(np.uint8))
159
+ buf = io.BytesIO()
160
+ img.save(buf, format=target_fmt.upper(), quality=quality)
161
+ buf.seek(0)
162
+ img2 = Image.open(buf)
163
+ return np.array(img2)
164
+
165
+ # ---------- Generate ---------- #
166
+ def generate(
167
+ api_key: str,
168
+ prompt: str,
169
+ n: int,
170
+ size: str,
171
+ quality: str,
172
+ ):
173
+ if not prompt:
174
+ raise gr.Error("Please enter a prompt.")
175
+ try:
176
+ agent = _client(api_key)
177
+ common_args = _common_kwargs(prompt, n, size, quality)
178
+ api_kwargs = {"retries": 3, **common_args}
179
+ resp = agent.image_generate(**api_kwargs)
180
+ imgs = _img_list(resp)
181
+ # if out_fmt in {"jpeg", "webp"}:
182
+ # imgs = [convert_to_format(img, out_fmt) for img in imgs]
183
+ return imgs
184
+ except Exception as e:
185
+ raise gr.Error(str(e))
186
+
187
+
188
+ # ---------- Edit ---------- #
189
+ def _bytes_from_numpy(arr: np.ndarray) -> bytes:
190
+ img = Image.fromarray(arr.astype(np.uint8))
191
+ buf = io.BytesIO()
192
+ img.save(buf, format="PNG")
193
+ return buf.getvalue()
194
+
195
+ def edit_image(
196
+ api_key: str,
197
+ image_numpy: Optional[np.ndarray],
198
+ prompt: str,
199
+ n: int,
200
+ size: str,
201
+ quality: str
202
+ ):
203
+ if image_numpy is None:
204
+ raise gr.Error("Please upload an image.")
205
+ if not prompt:
206
+ raise gr.Error("Please enter an edit prompt.")
207
+
208
+ img_bytes = _bytes_from_numpy(image_numpy)
209
+
210
+ try:
211
+ agent = _client(api_key)
212
+ common_args = _common_kwargs(prompt, n, size, quality)
213
+
214
+ image_tuple = ("image.png", img_bytes, "image/png")
215
+ api_kwargs = {"image": image_tuple, "retries": 3, **common_args}
216
+
217
+ resp = agent.image_edit(**api_kwargs)
218
+ imgs = _img_list(resp)
219
+ # if out_fmt in {"jpeg", "webp"}:
220
+ # imgs = [convert_to_format(img, out_fmt) for img in imgs]
221
+ return imgs
222
+ except Exception as e:
223
+ raise gr.Error(str(e))
224
+
225
+
226
+ # ---------- UI ---------- #
227
+ def build_ui():
228
+ with gr.Blocks(title="GPT-Image-1 (BYOT)") as demo:
229
+ gr.Markdown("""# 🐍 GPT-Image-1 Playground""")
230
+ with gr.Accordion("🔐 API key", open=False):
231
+ api = gr.Textbox(label="OpenAI API key", type="password", placeholder="gpt-...")
232
+
233
+ with gr.Row():
234
+ n_slider = gr.Slider(1, 4, value=1, step=1, label="Number of images (n)")
235
+ size = gr.Dropdown(SIZE_CHOICES, value="auto", label="Size")
236
+ quality = gr.Dropdown(QUALITY_CHOICES, value="auto", label="Quality")
237
+ with gr.Row():
238
+ out_fmt = gr.Radio(FORMAT_CHOICES, value="png", label="Output Format")
239
+
240
+ common_controls = [n_slider, size, quality]
241
+
242
+ with gr.Tabs():
243
+ with gr.TabItem("Generate"):
244
+ prompt_gen = gr.Textbox(
245
+ label="Prompt",
246
+ lines=3,
247
+ placeholder="Write down your prompt here",
248
+ autofocus=True,
249
+ container=False
250
+ )
251
+ btn_gen = gr.Button("Generate 🚀")
252
+ gallery_gen = gr.Gallery(columns=2, height="auto")
253
+
254
+ btn_gen.click(
255
+ generate,
256
+ inputs=[api, prompt_gen] + common_controls,
257
+ outputs=gallery_gen
258
+ )
259
+
260
+ with gr.TabItem("Edit"):
261
+ img_edit = gr.Image(type="numpy", label="Image to edit", height=400)
262
+ prompt_edit = gr.Textbox(label="Edit prompt", lines=2, placeholder="Write down your prompt here")
263
+ btn_edit = gr.Button("Edit 🖌️")
264
+ gallery_edit = gr.Gallery(columns=2, height="auto")
265
+ btn_edit.click(edit_image, inputs=[api, img_edit, prompt_edit] + common_controls, outputs=gallery_edit)
266
+
267
+ return demo
268
+
269
+
270
+ if __name__ == "__main__":
271
+ app = build_ui()
272
+ app.launch(share=os.getenv("GRADIO_SHARE") == "true", debug=os.getenv("GRADIO_DEBUG") == "true")
273
+