zenafey commited on
Commit
fe30a6a
·
1 Parent(s): 928dc00

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +156 -72
app.py CHANGED
@@ -6,11 +6,16 @@ import json
6
  import base64
7
  import os
8
  from io import BytesIO
 
9
  import PIL
 
10
  from PIL.ExifTags import TAGS
11
  import html
12
  import re
 
13
 
 
 
14
 
15
  class Prodia:
16
  def __init__(self, api_key, base=None):
@@ -18,19 +23,19 @@ class Prodia:
18
  self.headers = {
19
  "X-Prodia-Key": api_key
20
  }
21
-
22
  def generate(self, params):
23
  response = self._post(f"{self.base}/sd/generate", params)
24
  return response.json()
25
-
26
  def transform(self, params):
27
  response = self._post(f"{self.base}/sd/transform", params)
28
  return response.json()
29
-
30
  def controlnet(self, params):
31
  response = self._post(f"{self.base}/sd/controlnet", params)
32
  return response.json()
33
-
34
  def get_job(self, job_id):
35
  response = self._get(f"{self.base}/job/{job_id}")
36
  return response.json()
@@ -75,12 +80,13 @@ def image_to_base64(image_path):
75
  # Convert the image to bytes
76
  buffered = BytesIO()
77
  image.save(buffered, format="PNG") # You can change format to PNG if needed
78
-
79
  # Encode the bytes to base64
80
  img_str = base64.b64encode(buffered.getvalue())
81
 
82
  return img_str.decode('utf-8') # Convert bytes to string
83
 
 
84
  def remove_id_and_ext(text):
85
  text = re.sub(r'\[.*\]$', '', text)
86
  extension = text[-12:].strip()
@@ -90,6 +96,37 @@ def remove_id_and_ext(text):
90
  text = text[:-4]
91
  return text
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  def get_data(text):
94
  results = {}
95
  patterns = {
@@ -97,11 +134,11 @@ def get_data(text):
97
  'negative_prompt': r'Negative prompt: (.*)',
98
  'steps': r'Steps: (\d+),',
99
  'seed': r'Seed: (\d+),',
100
- 'sampler': r'Sampler:\s*([^\s,]+(?:\s+[^\s,]+)*)',
101
  'model': r'Model:\s*([^\s,]+)',
102
  'cfg_scale': r'CFG scale:\s*([\d\.]+)',
103
  'size': r'Size:\s*([0-9]+x[0-9]+)'
104
- }
105
  for key in ['prompt', 'negative_prompt', 'steps', 'seed', 'sampler', 'model', 'cfg_scale', 'size']:
106
  match = re.search(patterns[key], text)
107
  if match:
@@ -117,23 +154,24 @@ def get_data(text):
117
  results['h'] = None
118
  return results
119
 
 
120
  def send_to_txt2img(image):
121
-
122
  result = {tabs: gr.Tabs.update(selected="t2i")}
123
 
124
  try:
125
  text = image.info['parameters']
126
  data = get_data(text)
127
  result[prompt] = gr.update(value=data['prompt'])
128
- result[negative_prompt] = gr.update(value=data['negative_prompt']) if data['negative_prompt'] is not None else gr.update()
 
129
  result[steps] = gr.update(value=int(data['steps'])) if data['steps'] is not None else gr.update()
130
  result[seed] = gr.update(value=int(data['seed'])) if data['seed'] is not None else gr.update()
131
  result[cfg_scale] = gr.update(value=float(data['cfg_scale'])) if data['cfg_scale'] is not None else gr.update()
132
  result[width] = gr.update(value=int(data['w'])) if data['w'] is not None else gr.update()
133
  result[height] = gr.update(value=int(data['h'])) if data['h'] is not None else gr.update()
134
  result[sampler] = gr.update(value=data['sampler']) if data['sampler'] is not None else gr.update()
135
- if model in model_names:
136
- result[model] = gr.update(value=model_names[model])
137
  else:
138
  result[model] = gr.update()
139
  return result
@@ -153,7 +191,6 @@ def send_to_txt2img(image):
153
  return result
154
 
155
 
156
-
157
  prodia_client = Prodia(api_key=os.getenv("PRODIA_API_KEY"))
158
  model_list = prodia_client.list_models()
159
  model_names = {}
@@ -162,8 +199,12 @@ for model_name in model_list:
162
  name_without_ext = remove_id_and_ext(model_name)
163
  model_names[name_without_ext] = model_name
164
 
165
- def flip_text(prompt, negative_prompt, model, steps, sampler, cfg_scale, width, height, seed):
166
- result = prodia_client.generate({
 
 
 
 
167
  "prompt": prompt,
168
  "negative_prompt": negative_prompt,
169
  "model": model,
@@ -173,11 +214,47 @@ def flip_text(prompt, negative_prompt, model, steps, sampler, cfg_scale, width,
173
  "width": width,
174
  "height": height,
175
  "seed": seed
176
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
- job = prodia_client.wait(result)
 
 
 
179
 
180
- return job["imageUrl"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
 
183
  css = """
@@ -187,82 +264,82 @@ css = """
187
  """
188
 
189
  with gr.Blocks(css=css) as demo:
190
-
191
-
192
  with gr.Row():
193
  with gr.Column(scale=6):
194
- model = gr.Dropdown(interactive=True,value="absolutereality_v181.safetensors [3d9d4d2b]", show_label=True, label="Stable Diffusion Checkpoint", choices=prodia_client.list_models())
195
-
 
196
  with gr.Column(scale=1):
197
- gr.Markdown(elem_id="powered-by-prodia", value="AUTOMATIC1111 Stable Diffusion Web UI.<br>Powered by [Prodia](https://prodia.com).<br> For more features and faster gen times check out our [API Docs](https://docs.prodia.com/reference/getting-started-guide)")
 
198
 
199
  with gr.Tabs() as tabs:
200
  with gr.Tab("txt2img", id='t2i'):
201
  with gr.Row():
202
  with gr.Column(scale=6, min_width=600):
203
- prompt = gr.Textbox("space warrior, beautiful, female, ultrarealistic, soft lighting, 8k", placeholder="Prompt", show_label=False, lines=3)
204
- negative_prompt = gr.Textbox(placeholder="Negative Prompt", show_label=False, lines=3, value="3d, cartoon, anime, (deformed eyes, nose, ears, nose), bad anatomy, ugly")
 
 
205
  with gr.Column():
206
  text_button = gr.Button("Generate", variant='primary', elem_id="generate")
207
-
208
  with gr.Row():
209
  with gr.Column(scale=3):
210
  with gr.Tab("Generation"):
211
  with gr.Row():
212
  with gr.Column(scale=1):
213
- sampler = gr.Dropdown(value="Euler a", show_label=True, label="Sampling Method", choices=[
214
- "Euler",
215
- "Euler a",
216
- "LMS",
217
- "Heun",
218
- "DPM2",
219
- "DPM2 a",
220
- "DPM++ 2S a",
221
- "DPM++ 2M",
222
- "DPM++ SDE",
223
- "DPM fast",
224
- "DPM adaptive",
225
- "LMS Karras",
226
- "DPM2 Karras",
227
- "DPM2 a Karras",
228
- "DPM++ 2S a Karras",
229
- "DPM++ 2M Karras",
230
- "DPM++ SDE Karras",
231
- "DDIM",
232
- "PLMS",
233
- ])
234
-
 
235
  with gr.Column(scale=1):
236
  steps = gr.Slider(label="Sampling Steps", minimum=1, maximum=30, value=25, step=1)
237
-
238
  with gr.Row():
239
  with gr.Column(scale=1):
240
  width = gr.Slider(label="Width", maximum=1024, value=512, step=8)
241
  height = gr.Slider(label="Height", maximum=1024, value=512, step=8)
242
-
243
  with gr.Column(scale=1):
244
- batch_size = gr.Slider(label="Batch Size", maximum=1, value=1)
245
- batch_count = gr.Slider(label="Batch Count", maximum=1, value=1)
246
-
247
  cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, value=7, step=1)
248
  seed = gr.Number(label="Seed", value=-1)
249
-
250
-
251
  with gr.Column(scale=2):
252
- image_output = gr.Image(value="https://images.prodia.xyz/8ede1a7c-c0ee-4ded-987d-6ffed35fc477.png")
253
-
254
- text_button.click(flip_text, inputs=[prompt, negative_prompt, model, steps, sampler, cfg_scale, width, height, seed], outputs=image_output)
255
-
256
  with gr.Tab("PNG Info"):
257
  def plaintext_to_html(text, classname=None):
258
  content = "<br>\n".join(html.escape(x) for x in text.split('\n'))
259
-
260
  return f"<p class='{classname}'>{content}</p>" if classname else f"<p>{content}</p>"
261
-
262
-
263
  def get_exif_data(image):
264
  items = image.info
265
-
266
  info = ''
267
  for key, text in items.items():
268
  info += f"""
@@ -270,25 +347,32 @@ with gr.Blocks(css=css) as demo:
270
  <p><b>{plaintext_to_html(str(key))}</b></p>
271
  <p>{plaintext_to_html(str(text))}</p>
272
  </div>
273
- """.strip()+"\n"
274
-
275
  if len(info) == 0:
276
  message = "Nothing found in the image."
277
  info = f"<div><p>{message}<p></div>"
278
-
279
  return info
280
-
281
  with gr.Row():
282
  with gr.Column():
283
  image_input = gr.Image(type="pil")
284
-
285
  with gr.Column():
286
  exif_output = gr.HTML(label="EXIF Data")
287
  send_to_txt2img_btn = gr.Button("Send to txt2img")
288
-
289
- image_input.upload(get_exif_data, inputs=[image_input], outputs=exif_output)
290
- send_to_txt2img_btn.click(send_to_txt2img, inputs=[image_input], outputs=[tabs, prompt, negative_prompt, steps, seed,
291
- model, sampler, width, height, cfg_scale])
 
 
 
 
 
 
 
292
 
293
  demo.queue(concurrency_count=32)
294
- demo.launch()
 
6
  import base64
7
  import os
8
  from io import BytesIO
9
+ import math
10
  import PIL
11
+ from PIL import Image
12
  from PIL.ExifTags import TAGS
13
  import html
14
  import re
15
+ from threading import Thread
16
 
17
+ from dotenv import load_dotenv
18
+ load_dotenv()
19
 
20
  class Prodia:
21
  def __init__(self, api_key, base=None):
 
23
  self.headers = {
24
  "X-Prodia-Key": api_key
25
  }
26
+
27
  def generate(self, params):
28
  response = self._post(f"{self.base}/sd/generate", params)
29
  return response.json()
30
+
31
  def transform(self, params):
32
  response = self._post(f"{self.base}/sd/transform", params)
33
  return response.json()
34
+
35
  def controlnet(self, params):
36
  response = self._post(f"{self.base}/sd/controlnet", params)
37
  return response.json()
38
+
39
  def get_job(self, job_id):
40
  response = self._get(f"{self.base}/job/{job_id}")
41
  return response.json()
 
80
  # Convert the image to bytes
81
  buffered = BytesIO()
82
  image.save(buffered, format="PNG") # You can change format to PNG if needed
83
+
84
  # Encode the bytes to base64
85
  img_str = base64.b64encode(buffered.getvalue())
86
 
87
  return img_str.decode('utf-8') # Convert bytes to string
88
 
89
+
90
  def remove_id_and_ext(text):
91
  text = re.sub(r'\[.*\]$', '', text)
92
  extension = text[-12:].strip()
 
96
  text = text[:-4]
97
  return text
98
 
99
+
100
+ def create_grid(image_urls):
101
+ # Download first image to get size
102
+ response = requests.get(image_urls[0])
103
+ img_data = response.content
104
+ img = Image.open(BytesIO(img_data))
105
+ w, h = img.size
106
+
107
+ # Calculate rows and cols
108
+ num_images = len(image_urls)
109
+ num_cols = min(num_images, 3)
110
+ num_rows = math.ceil(num_images / num_cols)
111
+
112
+ # Create new rgba image
113
+ grid_w = num_cols * w
114
+ grid_h = num_rows * h
115
+ grid = Image.new('RGBA', (grid_w, grid_h), (0, 0, 0, 0))
116
+
117
+ # Download images and paste into grid
118
+ for index, img_url in enumerate(image_urls):
119
+ response = requests.get(img_url)
120
+ img_data = response.content
121
+ img = Image.open(BytesIO(img_data))
122
+
123
+ row = index // num_cols
124
+ col = index % num_cols
125
+ grid.paste(img, (col * w, row * h))
126
+
127
+ # Save image
128
+ return grid
129
+
130
  def get_data(text):
131
  results = {}
132
  patterns = {
 
134
  'negative_prompt': r'Negative prompt: (.*)',
135
  'steps': r'Steps: (\d+),',
136
  'seed': r'Seed: (\d+),',
137
+ 'sampler': r'Sampler:\s*([^\s,]+(?:\s+[^\s,]+)*)',
138
  'model': r'Model:\s*([^\s,]+)',
139
  'cfg_scale': r'CFG scale:\s*([\d\.]+)',
140
  'size': r'Size:\s*([0-9]+x[0-9]+)'
141
+ }
142
  for key in ['prompt', 'negative_prompt', 'steps', 'seed', 'sampler', 'model', 'cfg_scale', 'size']:
143
  match = re.search(patterns[key], text)
144
  if match:
 
154
  results['h'] = None
155
  return results
156
 
157
+
158
  def send_to_txt2img(image):
 
159
  result = {tabs: gr.Tabs.update(selected="t2i")}
160
 
161
  try:
162
  text = image.info['parameters']
163
  data = get_data(text)
164
  result[prompt] = gr.update(value=data['prompt'])
165
+ result[negative_prompt] = gr.update(value=data['negative_prompt']) if data[
166
+ 'negative_prompt'] is not None else gr.update()
167
  result[steps] = gr.update(value=int(data['steps'])) if data['steps'] is not None else gr.update()
168
  result[seed] = gr.update(value=int(data['seed'])) if data['seed'] is not None else gr.update()
169
  result[cfg_scale] = gr.update(value=float(data['cfg_scale'])) if data['cfg_scale'] is not None else gr.update()
170
  result[width] = gr.update(value=int(data['w'])) if data['w'] is not None else gr.update()
171
  result[height] = gr.update(value=int(data['h'])) if data['h'] is not None else gr.update()
172
  result[sampler] = gr.update(value=data['sampler']) if data['sampler'] is not None else gr.update()
173
+ if data['model'] in model_names:
174
+ result[model] = gr.update(value=model_names[data['model']])
175
  else:
176
  result[model] = gr.update()
177
  return result
 
191
  return result
192
 
193
 
 
194
  prodia_client = Prodia(api_key=os.getenv("PRODIA_API_KEY"))
195
  model_list = prodia_client.list_models()
196
  model_names = {}
 
199
  name_without_ext = remove_id_and_ext(model_name)
200
  model_names[name_without_ext] = model_name
201
 
202
+
203
+
204
+
205
+ def flip_text(prompt, negative_prompt, model, steps, sampler, cfg_scale, width, height, seed, batch_size, batch_count, gallery):
206
+
207
+ data = {
208
  "prompt": prompt,
209
  "negative_prompt": negative_prompt,
210
  "model": model,
 
214
  "width": width,
215
  "height": height,
216
  "seed": seed
217
+ }
218
+
219
+ total_images = []
220
+ count_threads = []
221
+
222
+ def generate_one_grid():
223
+ grid_images = []
224
+ size_threads = []
225
+
226
+ def generate_one_image():
227
+
228
+ result = prodia_client.generate(data)
229
+
230
+ job = prodia_client.wait(result)
231
+
232
+ grid_images.append(job['imageUrl'])
233
 
234
+ for y in range(batch_size):
235
+ t = Thread(target=generate_one_image)
236
+ size_threads.append(t)
237
+ t.start()
238
 
239
+ for t in size_threads:
240
+ t.join()
241
+
242
+ total_images.append(create_grid(grid_images))
243
+
244
+ for x in range(batch_count):
245
+ t = Thread(target=generate_one_grid)
246
+ count_threads.append(t)
247
+ t.start()
248
+
249
+ for t in count_threads:
250
+ t.join()
251
+
252
+ new_images_list = [img['name'] for img in gallery]
253
+
254
+ for image in total_images:
255
+ new_images_list.insert(0, image)
256
+
257
+ return {image_output: total_images, gallery_obj: new_images_list}
258
 
259
 
260
  css = """
 
264
  """
265
 
266
  with gr.Blocks(css=css) as demo:
 
 
267
  with gr.Row():
268
  with gr.Column(scale=6):
269
+ model = gr.Dropdown(interactive=True, value="absolutereality_v181.safetensors [3d9d4d2b]", show_label=True,
270
+ label="Stable Diffusion Checkpoint", choices=prodia_client.list_models())
271
+
272
  with gr.Column(scale=1):
273
+ gr.Markdown(elem_id="powered-by-prodia",
274
+ value="AUTOMATIC1111 Stable Diffusion Web UI.<br>Powered by [Prodia](https://prodia.com).<br> For more features and faster gen times check out our [API Docs](https://docs.prodia.com/reference/getting-started-guide)")
275
 
276
  with gr.Tabs() as tabs:
277
  with gr.Tab("txt2img", id='t2i'):
278
  with gr.Row():
279
  with gr.Column(scale=6, min_width=600):
280
+ prompt = gr.Textbox("space warrior, beautiful, female, ultrarealistic, soft lighting, 8k",
281
+ placeholder="Prompt", show_label=False, lines=3)
282
+ negative_prompt = gr.Textbox(placeholder="Negative Prompt", show_label=False, lines=3,
283
+ value="3d, cartoon, anime, (deformed eyes, nose, ears, nose), bad anatomy, ugly")
284
  with gr.Column():
285
  text_button = gr.Button("Generate", variant='primary', elem_id="generate")
286
+
287
  with gr.Row():
288
  with gr.Column(scale=3):
289
  with gr.Tab("Generation"):
290
  with gr.Row():
291
  with gr.Column(scale=1):
292
+ sampler = gr.Dropdown(value="Euler a", show_label=True, label="Sampling Method",
293
+ choices=[
294
+ "Euler",
295
+ "Euler a",
296
+ "LMS",
297
+ "Heun",
298
+ "DPM2",
299
+ "DPM2 a",
300
+ "DPM++ 2S a",
301
+ "DPM++ 2M",
302
+ "DPM++ SDE",
303
+ "DPM fast",
304
+ "DPM adaptive",
305
+ "LMS Karras",
306
+ "DPM2 Karras",
307
+ "DPM2 a Karras",
308
+ "DPM++ 2S a Karras",
309
+ "DPM++ 2M Karras",
310
+ "DPM++ SDE Karras",
311
+ "DDIM",
312
+ "PLMS",
313
+ ])
314
+
315
  with gr.Column(scale=1):
316
  steps = gr.Slider(label="Sampling Steps", minimum=1, maximum=30, value=25, step=1)
317
+
318
  with gr.Row():
319
  with gr.Column(scale=1):
320
  width = gr.Slider(label="Width", maximum=1024, value=512, step=8)
321
  height = gr.Slider(label="Height", maximum=1024, value=512, step=8)
322
+
323
  with gr.Column(scale=1):
324
+ batch_size = gr.Slider(label="Batch Size", minimum=1, maximum=9, value=1, step=1)
325
+ batch_count = gr.Slider(label="Batch Count", minimum=1, maximum=100, value=1, step=1)
326
+
327
  cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, value=7, step=1)
328
  seed = gr.Number(label="Seed", value=-1)
329
+
 
330
  with gr.Column(scale=2):
331
+ image_output = gr.Gallery(value=["https://images.prodia.xyz/8ede1a7c-c0ee-4ded-987d-6ffed35fc477.png"], preview=True)
332
+
 
 
333
  with gr.Tab("PNG Info"):
334
  def plaintext_to_html(text, classname=None):
335
  content = "<br>\n".join(html.escape(x) for x in text.split('\n'))
336
+
337
  return f"<p class='{classname}'>{content}</p>" if classname else f"<p>{content}</p>"
338
+
339
+
340
  def get_exif_data(image):
341
  items = image.info
342
+
343
  info = ''
344
  for key, text in items.items():
345
  info += f"""
 
347
  <p><b>{plaintext_to_html(str(key))}</b></p>
348
  <p>{plaintext_to_html(str(text))}</p>
349
  </div>
350
+ """.strip() + "\n"
351
+
352
  if len(info) == 0:
353
  message = "Nothing found in the image."
354
  info = f"<div><p>{message}<p></div>"
355
+
356
  return info
357
+
358
  with gr.Row():
359
  with gr.Column():
360
  image_input = gr.Image(type="pil")
361
+
362
  with gr.Column():
363
  exif_output = gr.HTML(label="EXIF Data")
364
  send_to_txt2img_btn = gr.Button("Send to txt2img")
365
+
366
+ with gr.Tab("Gallery"):
367
+ gallery_obj = gr.Gallery(height=1000, columns=5)
368
+
369
+ text_button.click(flip_text,
370
+ inputs=[prompt, negative_prompt, model, steps, sampler, cfg_scale, width, height, seed, batch_size, batch_count,
371
+ gallery_obj], outputs=[image_output, gallery_obj])
372
+ image_input.upload(get_exif_data, inputs=[image_input], outputs=exif_output)
373
+ send_to_txt2img_btn.click(send_to_txt2img, inputs=[image_input],
374
+ outputs=[tabs, prompt, negative_prompt, steps, seed,
375
+ model, sampler, width, height, cfg_scale])
376
 
377
  demo.queue(concurrency_count=32)
378
+ demo.launch()