marscrazy commited on
Commit
97731f9
1 Parent(s): acde5ff

change backend

Browse files
Files changed (3) hide show
  1. README.md +5 -5
  2. app.py +108 -388
  3. ui_functions.py +6 -6
README.md CHANGED
@@ -1,13 +1,13 @@
1
  ---
2
  title: AltDiffusion
3
- emoji: 🌖
4
- colorFrom: yellow
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 3.4
8
  app_file: app.py
9
  pinned: false
10
- license: openrail
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: AltDiffusion
3
+ emoji: 📉
4
+ colorFrom: purple
5
+ colorTo: gray
6
  sdk: gradio
7
+ sdk_version: 3.10.1
8
  app_file: app.py
9
  pinned: false
10
+ license: creativeml-openrail-m
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,76 +1,17 @@
1
- import imp
2
- import gradio as gr
3
  import io
4
- from PIL import Image, PngImagePlugin
 
 
 
5
  import base64
6
  import requests
7
- import json
8
  import ui_functions as uifn
9
  from css_and_js import js, call_JS
10
- import re
11
- from share_btn import community_icon_html, loading_icon_html, share_js
12
-
13
- # reload
14
- txt2img_defaults = {
15
- 'prompt': '',
16
- 'ddim_steps': 50,
17
- 'toggles': [1, 2, 3],
18
- 'sampler_name': 'k_lms',
19
- 'ddim_eta': 0.0,
20
- 'n_iter': 1,
21
- 'batch_size': 1,
22
- 'cfg_scale': 7.5,
23
- 'seed': '',
24
- 'height': 512,
25
- 'width': 512,
26
- 'fp': None,
27
- 'variant_amount': 0.0,
28
- 'variant_seed': '',
29
- 'submit_on_enter': 'Yes',
30
- }
31
-
32
- img2img_defaults = {
33
- 'prompt': '',
34
- 'ddim_steps': 50,
35
- 'toggles': [1, 4, 5],
36
- 'sampler_name': 'k_lms',
37
- 'ddim_eta': 0.0,
38
- 'n_iter': 1,
39
- 'batch_size': 1,
40
- 'cfg_scale': 5.0,
41
- 'denoising_strength': 0.75,
42
- 'mask_mode': 1,
43
- 'resize_mode': 0,
44
- 'seed': '',
45
- 'height': 512,
46
- 'width': 512,
47
- 'fp': None,
48
- }
49
- sample_img2img = None
50
- job_manager = None
51
- RealESRGAN = True
52
- show_embeddings = False
53
-
54
- img2img_resize_modes = [
55
- "Just resize",
56
- "Crop and resize",
57
- "Resize and fill",
58
- ]
59
 
60
- img2img_toggles = [
61
- 'Create prompt matrix (separate multiple prompts using |, and get all combinations of them)',
62
- 'Normalize Prompt Weights (ensure sum of weights add up to 1.0)',
63
- 'Loopback (use images from previous batch when creating next batch)',
64
- 'Random loopback seed',
65
- 'Save individual images',
66
- 'Save grid',
67
- 'Sort samples by prompt',
68
- 'Write sample info files',
69
- 'Write sample info to one file',
70
- 'jpg samples',
71
- ]
72
-
73
- img2img_toggle_defaults = [img2img_toggles[i] for i in img2img_defaults['toggles']]
74
 
75
  def read_content(file_path: str) -> str:
76
  """read the content of target file
@@ -80,13 +21,6 @@ def read_content(file_path: str) -> str:
80
 
81
  return content
82
 
83
- def base2picture(resbase64):
84
- res=resbase64.split(',')[1]
85
- img_b64decode = base64.b64decode(res)
86
- image = io.BytesIO(img_b64decode)
87
- img = Image.open(image)
88
- return img
89
-
90
  def filter_content(raw_style: str):
91
  if "(" in raw_style:
92
  i = raw_style.index("(")
@@ -98,105 +32,108 @@ def filter_content(raw_style: str):
98
  else :
99
  return raw_style[:i]
100
 
101
- def request_images(raw_text, class_draw, style_draw, batch_size, sr_option, w, h, seed):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  if filter_content(class_draw) != "国画":
103
  if filter_content(class_draw) != "通用":
104
  raw_text = raw_text + f",{filter_content(class_draw)}"
105
 
106
  for sty in style_draw:
107
  raw_text = raw_text + f",{filter_content(sty)}"
108
- print(f"raw text is {raw_text}")
109
- url = "http://flagart.baai.ac.cn/api/general/"
110
  elif filter_content(class_draw) == "国画":
111
- if raw_text.endswith("国画"):
112
- pass
113
- else :
114
- raw_text = raw_text + ",国画"
115
- url = "http://flagart.baai.ac.cn/api/guohua/"
116
-
117
- d = {"data":[raw_text, int(batch_size), sr_option, w, h, seed]}
118
- # print('+++++++++right here+++++++')
119
- # print(w,h,seed)
120
- r = requests.post(url, json=d, headers={"Content-Type": "application/json", "Accept": "*/*", "Accept-Encoding": "gzip, deflate, br", "Connection": "keep-alive"})
121
- result_text = r.text
122
- content = json.loads(result_text)["data"][0]
123
- images = []
124
- for i in range(int(batch_size)):
125
- # print(content[i])
126
- images.append(base2picture(content[i]))
127
 
128
- return images
129
-
130
- def sr_request_images(img_str, idx, w, h, seed):
131
- idx_map = {
132
- "图片1(img1)":0,
133
- "图片2(img2)":1,
134
- "图片3(img3)":2,
135
- "图片4(img4)":3,
136
- }
137
- idx = idx_map[idx]
138
- # image_data = re.sub('^data:image/.+;base64,', '', img_str[idx])
139
- image_data = img_str[idx]
140
- # print(image_data)
141
- d = {"data":[image_data, 0, False, w, h, seed]} # batch_size设置为0,sr_opt设置为False, 新加了3个参数w,h,seed,随便放就行,进去不会走这个路
142
- # print(w,h,seed)
143
- url = "http://flagart.baai.ac.cn/api/general/"
144
- r = requests.post(url, json=d, headers={"Content-Type": "application/json", "Accept": "*/*", "Accept-Encoding": "gzip, deflate, br", "Connection": "keep-alive"})
145
- result_text = r.text
146
- # print(result_text)
147
- content = json.loads(result_text)["data"][0]
148
- # print(content)
149
- images = [base2picture(content[0])]
150
 
151
  return images
152
 
153
 
154
- def encode_pil_to_base64(pil_image):
155
- with io.BytesIO() as output_bytes:
156
-
157
- # Copy any text-only metadata
158
- use_metadata = False
159
- metadata = PngImagePlugin.PngInfo()
160
- for key, value in pil_image.info.items():
161
- if isinstance(key, str) and isinstance(value, str):
162
- metadata.add_text(key, value)
163
- use_metadata = True
164
-
165
- pil_image.save(
166
- output_bytes, "PNG", pnginfo=(metadata if use_metadata else None)
167
- )
168
- bytes_data = output_bytes.getvalue()
169
- base64_str = str(base64.b64encode(bytes_data), "utf-8")
170
- return "data:image/png;base64," + base64_str
171
-
172
- def img2img(*args):
173
-
174
- # 处理image
175
- for i, item in enumerate(args):
176
- # print(type(item))
177
- if type(item) == dict:
178
- args[i]['image'] = encode_pil_to_base64(item['image'])
179
- args[i]['mask'] = encode_pil_to_base64(item['mask'])
180
- # else:
181
- # print(i,type(item))
182
- # print(item)
183
-
184
- batch_size = args[8]
185
-
186
- url = "http://flagart.baai.ac.cn/api/img2img/"
187
- d = {"data":args}
188
- r = requests.post(url, json=d, headers={"Content-Type": "application/json", "Accept": "*/*", "Accept-Encoding": "gzip, deflate, br", "Connection": "keep-alive"})
189
- # print(r)
190
- result_text = r.text
191
- content = json.loads(result_text)["data"][0]
192
- images = []
193
- for i in range(batch_size):
194
- # print(content[i])
195
- images.append(base2picture(content[i]))
196
- # content = json.loads(result_text)
197
- # print(result_text)
198
- # print("服务器已经把东西返回来啦!!!!!!!乌拉乌拉!!!!!")
199
- return images
200
 
201
 
202
  examples = [
@@ -217,10 +154,6 @@ if __name__ == "__main__":
217
 
218
  with block:
219
  gr.HTML(read_content("header.html"))
220
- # with gr.Group(elem_id="share-btn-container"):
221
- # community_icon = gr.HTML(community_icon_html)
222
- # loading_icon = gr.HTML(loading_icon_html)
223
- # share_button = gr.Button("Share to community", elem_id="share-btn")
224
  with gr.Tabs(elem_id='tabss') as tabs:
225
 
226
  with gr.TabItem("文生图(Text-to-img)", id='txt2img_tab'):
@@ -272,32 +205,22 @@ if __name__ == "__main__":
272
  # interactive=True,
273
  # )
274
  sample_size = gr.Radio(choices=["1","2","3","4"], value="1", show_label=True, label='生成数量(number)')
275
- seed = gr.Number(-1, label='seed', interactive=True)
276
  with gr.Row().style(mobile_collapse=False, equal_height=True):
277
  w = gr.Slider(512,1024,value=512, step=64, label="width")
278
  h = gr.Slider(512,1024,value=512, step=64, label="height")
279
- with gr.Row(visible=False).style(mobile_collapse=False, equal_height=True):
280
- sr_option = gr.Checkbox(value=False, label="是否使用超分(Whether to use super-resolution)")
281
 
282
  gallery = gr.Gallery(
283
  label="Generated images", show_label=False, elem_id="gallery"
284
  ).style(grid=[2,2])
285
  gr.Examples(examples=examples, fn=request_images, inputs=text, outputs=gallery, examples_per_page=100)
286
  with gr.Row().style(mobile_collapse=False, equal_height=True):
287
- img_choices = gr.Dropdown(["图片1(img1)"],label='请选择一张图片发送到图生图或者超分',show_label=True,value="图片1(img1)")
288
  with gr.Row().style(mobile_collapse=False, equal_height=True):
289
  output_txt2img_copy_to_input_btn = gr.Button("发送图片到图生图(Sent the image to img2img)").style(
290
  margin=False,
291
  rounded=(True, True, True, True),
292
  )
293
- output_txt2img_sr_btn = gr.Button("将选择的图片进行超分(super-resolution)").style(
294
- margin=False,
295
- rounded=(True, True, True, True),
296
- )
297
-
298
- sr_gallery = gr.Gallery(
299
- label="SR images", show_label=True, elem_id="sr_gallery"
300
- ).style(grid=[1,1])
301
 
302
  with gr.Row():
303
  prompt = gr.Markdown("提示(Prompt):", visible=False)
@@ -308,16 +231,8 @@ if __name__ == "__main__":
308
 
309
 
310
 
311
- # 并没有什么软用,转跳还是很卡,而且只能转跳一次
312
- # go_to_img2img_btn = gr.Button("转至图生图(Go to img2img)").style(
313
- # margin=False,
314
- # rounded=(True, True, True, True),
315
- # )
316
-
317
- text.submit(request_images, inputs=[text, class_draw, style_draw, sample_size, sr_option, w, h, seed], outputs=gallery)
318
- btn.click(request_images, inputs=[text, class_draw, style_draw, sample_size, sr_option, w, h, seed], outputs=gallery)
319
- # to do here
320
- output_txt2img_sr_btn.click(sr_request_images, inputs=[gallery, img_choices, w, h, seed], outputs=[sr_gallery])
321
 
322
  sample_size.change(
323
  fn=uifn.change_img_choices,
@@ -331,8 +246,8 @@ if __name__ == "__main__":
331
  elem_id='img2img_prompt_input',
332
  placeholder="神奇的森林,流淌的河流.",
333
  lines=1,
334
- max_lines=1 if txt2img_defaults['submit_on_enter'] == 'Yes' else 25,
335
- value=img2img_defaults['prompt'],
336
  show_label=False).style()
337
 
338
  img2img_btn_mask = gr.Button("Generate", variant="primary", visible=False,
@@ -342,7 +257,7 @@ if __name__ == "__main__":
342
  with gr.Row().style(equal_height=False):
343
  #with gr.Column():
344
  img2img_image_mask = gr.Image(
345
- value=sample_img2img,
346
  source="upload",
347
  interactive=True,
348
  tool="sketch",
@@ -350,71 +265,10 @@ if __name__ == "__main__":
350
  elem_id="img2img_mask",
351
  image_mode="RGBA"
352
  )
353
- img2img_image_editor = gr.Image(
354
- value=sample_img2img,
355
- source="upload",
356
- interactive=True,
357
- tool="select",
358
- type='pil',
359
- visible=False,
360
- image_mode="RGBA",
361
- elem_id="img2img_editor"
362
- )
363
- with gr.Tabs(visible=False):
364
- with gr.TabItem("编辑设置"):
365
- with gr.Row():
366
- # disable Uncrop for now
367
- choices=["Mask", "Crop", "Uncrop"]
368
- img2img_image_editor_mode = gr.Radio(choices=["Mask"],
369
- label="编辑模式",
370
- value="Mask", elem_id='edit_mode_select',
371
- visible=True)
372
- img2img_mask = gr.Radio(choices=["保留mask区域", "生成mask区域"],
373
- label="Mask 方式",
374
- #value=img2img_mask_modes[img2img_defaults['mask_mode']],
375
- value = "生成mask区域",
376
- visible=True)
377
-
378
- img2img_mask_blur_strength = gr.Slider(minimum=1, maximum=10, step=1,
379
- label="How much blurry should the mask be? (to avoid hard edges)",
380
- value=3, visible=False)
381
-
382
- img2img_resize = gr.Radio(label="Resize mode",
383
- choices=["Just resize", "Crop and resize",
384
- "Resize and fill"],
385
- value=img2img_resize_modes[
386
- img2img_defaults['resize_mode']], visible=False)
387
-
388
- img2img_painterro_btn = gr.Button("Advanced Editor",visible=False)
389
- # with gr.TabItem("Hints",visible=False):
390
- # img2img_help = gr.Markdown(visible=False, value=uifn.help_text)
391
  gr.Markdown('#### 编辑后的图片')
392
  with gr.Row():
393
  output_img2img_gallery = gr.Gallery(label="Images", elem_id="img2img_gallery_output").style(
394
  grid=[4,4,4] )
395
- img2img_job_ui = job_manager.draw_gradio_ui() if job_manager else None
396
- with gr.Column(visible=False):
397
- with gr.Tabs(visible=False):
398
- with gr.TabItem("", id="img2img_actions_tab",visible=False):
399
- gr.Markdown("Select an image, then press one of the buttons below")
400
- with gr.Row():
401
- output_img2img_copy_to_clipboard_btn = gr.Button("Copy to clipboard")
402
- output_img2img_copy_to_input_btn = gr.Button("Push to img2img input")
403
- output_img2img_copy_to_mask_btn = gr.Button("Push to img2img input mask")
404
-
405
- gr.Markdown("Warning: This will clear your current image and mask settings!")
406
- with gr.TabItem("", id="img2img_output_info_tab",visible=False):
407
- output_img2img_params = gr.Textbox(label="Generation parameters")
408
- with gr.Row():
409
- output_img2img_copy_params = gr.Button("Copy full parameters").click(
410
- inputs=output_img2img_params, outputs=[],
411
- _js='(x) => {navigator.clipboard.writeText(x.replace(": ",":"))}', fn=None,
412
- show_progress=False)
413
- output_img2img_seed = gr.Number(label='Seed', interactive=False, visible=False)
414
- output_img2img_copy_seed = gr.Button("Copy only seed").click(
415
- inputs=output_img2img_seed, outputs=[],
416
- _js=call_JS("gradioInputToClipboard"), fn=None, show_progress=False)
417
- output_img2img_stats = gr.HTML(label='Stats')
418
  with gr.Row():
419
  gr.Markdown('提示(prompt):')
420
  with gr.Row():
@@ -423,129 +277,18 @@ if __name__ == "__main__":
423
  gr.Markdown('Please select an image to cover up a part of the area and enter a text description.')
424
  gr.Markdown('# 编辑设置',visible=False)
425
 
426
- with gr.Row(visible=False):
427
- with gr.Column():
428
- img2img_width = gr.Slider(minimum=64, maximum=2048, step=64, label="图片宽度",
429
- value=img2img_defaults["width"])
430
- img2img_height = gr.Slider(minimum=64, maximum=2048, step=64, label="图片高度",
431
- value=img2img_defaults["height"])
432
- img2img_cfg = gr.Slider(minimum=-40.0, maximum=30.0, step=0.5,
433
- label='文本引导强度',
434
- value=img2img_defaults['cfg_scale'], elem_id='cfg_slider')
435
- img2img_seed = gr.Textbox(label="随机种子", lines=1, max_lines=1,
436
- value=img2img_defaults["seed"])
437
- img2img_batch_count = gr.Slider(minimum=1, maximum=50, step=1,
438
- label='生成数量',
439
- value=img2img_defaults['n_iter'])
440
- img2img_dimensions_info_text_box = gr.Textbox(
441
- label="长宽比设置")
442
- with gr.Column():
443
- img2img_steps = gr.Slider(minimum=1, maximum=250, step=1, label="采样步数",
444
- value=img2img_defaults['ddim_steps'])
445
-
446
- img2img_sampling = gr.Dropdown(label='采样方式',
447
- choices=["DDIM", 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler',
448
- 'k_heun', 'k_lms'],
449
- value=img2img_defaults['sampler_name'])
450
-
451
- img2img_denoising = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising Strength',
452
- value=img2img_defaults['denoising_strength'],visible=False)
453
-
454
- img2img_toggles = gr.CheckboxGroup(label='', choices=img2img_toggles,
455
- value=img2img_toggle_defaults,visible=False)
456
-
457
- img2img_realesrgan_model_name = gr.Dropdown(label='RealESRGAN model',
458
- choices=['RealESRGAN_x4plus',
459
- 'RealESRGAN_x4plus_anime_6B'],
460
- value='RealESRGAN_x4plus',
461
- visible=RealESRGAN is not None) # TODO: Feels like I shouldnt slot it in here.
462
-
463
- img2img_embeddings = gr.File(label="Embeddings file for textual inversion",
464
- visible=show_embeddings)
465
-
466
- img2img_image_editor_mode.change(
467
- uifn.change_image_editor_mode,
468
- [img2img_image_editor_mode,
469
- img2img_image_editor,
470
- img2img_image_mask,
471
- img2img_resize,
472
- img2img_width,
473
- img2img_height
474
- ],
475
- [img2img_image_editor, img2img_image_mask, img2img_btn_editor, img2img_btn_mask,
476
- img2img_painterro_btn, img2img_mask, img2img_mask_blur_strength]
477
- )
478
-
479
- # 这个函数之前注释掉了,但是没有出现bug,看一下这个的作用;
480
- img2img_image_editor_mode.change(
481
- uifn.update_image_mask,
482
- [img2img_image_editor, img2img_resize, img2img_width, img2img_height],
483
- img2img_image_mask
484
- )
485
 
486
- # 把上面这个注释掉下面就不管用了,很神奇无法理解... 后来又好了
487
- # output_txt2img_copy_to_input_btn.click(
488
- # uifn.copy_img_to_input,
489
- # [gallery],
490
- # [tabs, img2img_image_editor, img2img_image_mask],
491
- # _js=call_JS("moveImageFromGallery",
492
- # fromId="txt2img_gallery_output",
493
- # toId="img2img_mask")
494
- # )
495
  output_txt2img_copy_to_input_btn.click(
496
  uifn.copy_img_to_input,
497
  [gallery, img_choices],
498
- [tabs, img2img_image_editor, img2img_image_mask, move_prompt_zh, move_prompt_en, prompt]
499
  )
500
 
501
- # go_to_img2img_btn.click(
502
- # uifn.go_to_img2img,
503
- # [],
504
- # [tabs],
505
- # )
506
-
507
- # 下面这几个函数现在都没什么用
508
- output_img2img_copy_to_input_btn.click(
509
- uifn.copy_img_to_edit,
510
- [output_img2img_gallery],
511
- [img2img_image_editor, tabs, img2img_image_editor_mode],
512
- _js=call_JS("moveImageFromGallery",
513
- fromId="gallery",
514
- toId="img2img_editor")
515
- )
516
- output_img2img_copy_to_mask_btn.click(
517
- uifn.copy_img_to_mask,
518
- [output_img2img_gallery],
519
- [img2img_image_mask, tabs, img2img_image_editor_mode],
520
- _js=call_JS("moveImageFromGallery",
521
- fromId="img2img_gallery_output",
522
- toId="img2img_editor")
523
- )
524
-
525
- output_img2img_copy_to_clipboard_btn.click(fn=None, inputs=output_img2img_gallery, outputs=[],
526
- _js=call_JS("copyImageFromGalleryToClipboard",
527
- fromId="img2img_gallery_output")
528
- )
529
 
530
  img2img_func = img2img
531
- img2img_inputs = [img2img_prompt, img2img_image_editor_mode, img2img_mask,
532
- img2img_mask_blur_strength, img2img_steps, img2img_sampling, img2img_toggles,
533
- img2img_realesrgan_model_name, img2img_batch_count, img2img_cfg,
534
- img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize,
535
- img2img_image_mask]
536
- # img2img_outputs = [output_img2img_gallery, output_img2img_seed, output_img2img_params,
537
- # output_img2img_stats]
538
-
539
  img2img_outputs = [output_img2img_gallery]
540
 
541
- # If a JobManager was passed in then wrap the Generate functions
542
- if img2img_job_ui:
543
- img2img_func, img2img_inputs, img2img_outputs = img2img_job_ui.wrap_func(
544
- func=img2img_func,
545
- inputs=img2img_inputs,
546
- outputs=img2img_outputs,
547
- )
548
-
549
  img2img_btn_mask.click(
550
  img2img_func,
551
  img2img_inputs,
@@ -553,11 +296,6 @@ if __name__ == "__main__":
553
  )
554
 
555
  def img2img_submit_params():
556
- # print([img2img_prompt, img2img_image_editor_mode, img2img_mask,
557
- # img2img_mask_blur_strength, img2img_steps, img2img_sampling, img2img_toggles,
558
- # img2img_realesrgan_model_name, img2img_batch_count, img2img_cfg,
559
- # img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize,
560
- # img2img_image_editor, img2img_image_mask, img2img_embeddings])
561
  return (img2img_func,
562
  img2img_inputs,
563
  img2img_outputs)
@@ -569,25 +307,7 @@ if __name__ == "__main__":
569
  _js=call_JS("clickFirstVisibleButton",
570
  rowId="prompt_row"))
571
 
572
- img2img_painterro_btn.click(None,
573
- [img2img_image_editor, img2img_image_mask, img2img_image_editor_mode],
574
- [img2img_image_editor, img2img_image_mask],
575
- _js=call_JS("Painterro.init", toId="img2img_editor")
576
- )
577
-
578
- img2img_width.change(fn=uifn.update_dimensions_info, inputs=[img2img_width, img2img_height],
579
- outputs=img2img_dimensions_info_text_box)
580
- img2img_height.change(fn=uifn.update_dimensions_info, inputs=[img2img_width, img2img_height],
581
- outputs=img2img_dimensions_info_text_box)
582
-
583
- # share_button.click(
584
- # None,
585
- # [],
586
- # [],
587
- # _js=share_js,
588
- # )
589
-
590
  gr.HTML(read_content("footer.html"))
591
  # gr.Image('./contributors.png')
592
 
593
- block.queue(max_size=50, concurrency_count=20).launch()
 
 
 
1
  import io
2
+ import re
3
+ import imp
4
+ import time
5
+ import json
6
  import base64
7
  import requests
8
+ import gradio as gr
9
  import ui_functions as uifn
10
  from css_and_js import js, call_JS
11
+ from PIL import Image, PngImagePlugin, ImageChops
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
+ url_host = "http://flagstudio.baai.ac.cn"
14
+ token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoiZjAxOGMxMzJiYTUyNDBjMzk5NTMzYTI5YjBmMzZiODMiLCJhcHBfbmFtZSI6IndlYiIsImlkZW50aXR5X3R5cGUiOiIyIiwidXNlcl9yb2xlIjoiMiIsImp0aSI6IjVjMmQzMjdiLWI5Y2MtNDhiZS1hZWQ4LTllMjQ4MDk4NzMxYyIsIm5iZiI6MTY2OTAwNjE5NywiZXhwIjoxOTg0MzY2MTk3LCJpYXQiOjE2NjkwMDYxOTd9.9B3MDk8wA6iWH5puXjcD19tJJ4Ox7mdpRyWZs5Kwt70"
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  def read_content(file_path: str) -> str:
17
  """read the content of target file
 
21
 
22
  return content
23
 
 
 
 
 
 
 
 
24
  def filter_content(raw_style: str):
25
  if "(" in raw_style:
26
  i = raw_style.index("(")
 
32
  else :
33
  return raw_style[:i]
34
 
35
+ def upload_image(img):
36
+ url = url_host + "/api/v1/image/get-upload-link"
37
+ headers = {"token": token}
38
+ r = requests.post(url, json={}, headers=headers)
39
+ if r.status_code != 200:
40
+ raise gr.Error(r.reason)
41
+ head_res = r.json()
42
+ if head_res["code"] != 0:
43
+ raise gr.Error("Unknown error")
44
+ image_id = head_res["data"]["image_id"]
45
+ image_url = head_res["data"]["url"]
46
+ image_headers = head_res["data"]["headers"]
47
+
48
+ imgBytes = io.BytesIO()
49
+ img.save(imgBytes, "PNG")
50
+ imgBytes = imgBytes.getvalue()
51
+
52
+ r = requests.put(image_url, data=imgBytes, headers=image_headers)
53
+ if r.status_code != 200:
54
+ raise gr.Error(r.reason)
55
+ return image_id, image_url
56
+
57
+ def post_reqest(seed, prompt, width, height, image_num, img=None, mask=None):
58
+ data = {
59
+ "type": "gen-image",
60
+ "gen_image_num": image_num,
61
+ "parameters": {
62
+ "width": width, # output height width
63
+ "height": height, # output image height
64
+ "prompts": [prompt],
65
+ }
66
+ }
67
+ data["parameters"]["seed"] = int(seed)
68
+ if img is not None:
69
+ # Upload image
70
+ image_id, image_url = upload_image(img)
71
+ data["parameters"]["init_image"] = {
72
+ "image_id": image_id,
73
+ "url": image_url,
74
+ "width": img.width,
75
+ "height": img.height,
76
+ }
77
+ if mask is not None:
78
+ # Upload mask
79
+ extrama = mask.convert("L").getextrema()
80
+ if extrama[1] > 0:
81
+ mask_id, mask_url = upload_image(mask)
82
+ data["parameters"]["mask_image"] = {
83
+ "image_id": mask_id,
84
+ "url": mask_url,
85
+ "width": mask.width,
86
+ "height": mask.height,
87
+ }
88
+ headers = {"token": token}
89
+ # Send create task request
90
+ # url = "http://flagstudio.baai.ac.cn/api/v1/task/create"
91
+ url = url_host+"/api/v1/task/create"
92
+ r = requests.post(url, json=data, headers=headers)
93
+ if r.status_code != 200:
94
+ raise gr.Error(r.reason)
95
+ create_res = r.json()
96
+ task_id = create_res["data"]["task_id"]
97
+
98
+ # Get result
99
+ url = url_host+"/api/v1/task/status"
100
+ while True:
101
+ r = requests.post(url, json=create_res["data"], headers=headers)
102
+ if r.status_code != 200:
103
+ raise gr.Error(r.reason)
104
+ res = r.json()
105
+ if res["code"] == 6002:
106
+ # Running
107
+ time.sleep(1)
108
+ continue
109
+ elif res["code"] == 0:
110
+ # Finished
111
+ images = []
112
+ for img_info in res["data"]["images"]:
113
+ img_res = requests.get(img_info["url"])
114
+ images.append(Image.open(io.BytesIO(img_res.content)).convert("RGB"))
115
+ return images
116
+ else:
117
+ raise gr.Error(f"Error code: {res['code']}")
118
+
119
+ def request_images(raw_text, class_draw, style_draw, batch_size, w, h, seed):
120
  if filter_content(class_draw) != "国画":
121
  if filter_content(class_draw) != "通用":
122
  raw_text = raw_text + f",{filter_content(class_draw)}"
123
 
124
  for sty in style_draw:
125
  raw_text = raw_text + f",{filter_content(sty)}"
 
 
126
  elif filter_content(class_draw) == "国画":
127
+ raw_text = raw_text + ",国画,水墨画,大作,黑白,高清,传统"
128
+ print(f"raw text is {raw_text}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
+ images = post_reqest(seed, raw_text, w, h, int(batch_size))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  return images
133
 
134
 
135
+ def img2img(prompt, image_and_mask):
136
+ return post_reqest(0, prompt, 512, 512, 1, image_and_mask["image"], image_and_mask["mask"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
 
139
  examples = [
 
154
 
155
  with block:
156
  gr.HTML(read_content("header.html"))
 
 
 
 
157
  with gr.Tabs(elem_id='tabss') as tabs:
158
 
159
  with gr.TabItem("文生图(Text-to-img)", id='txt2img_tab'):
 
205
  # interactive=True,
206
  # )
207
  sample_size = gr.Radio(choices=["1","2","3","4"], value="1", show_label=True, label='生成数量(number)')
208
+ seed = gr.Number(0, label='seed', interactive=True)
209
  with gr.Row().style(mobile_collapse=False, equal_height=True):
210
  w = gr.Slider(512,1024,value=512, step=64, label="width")
211
  h = gr.Slider(512,1024,value=512, step=64, label="height")
 
 
212
 
213
  gallery = gr.Gallery(
214
  label="Generated images", show_label=False, elem_id="gallery"
215
  ).style(grid=[2,2])
216
  gr.Examples(examples=examples, fn=request_images, inputs=text, outputs=gallery, examples_per_page=100)
217
  with gr.Row().style(mobile_collapse=False, equal_height=True):
218
+ img_choices = gr.Dropdown(["图片1(img1)"],label='请选择一张图片发送到图生图',show_label=True,value="图片1(img1)")
219
  with gr.Row().style(mobile_collapse=False, equal_height=True):
220
  output_txt2img_copy_to_input_btn = gr.Button("发送图片到图生图(Sent the image to img2img)").style(
221
  margin=False,
222
  rounded=(True, True, True, True),
223
  )
 
 
 
 
 
 
 
 
224
 
225
  with gr.Row():
226
  prompt = gr.Markdown("提示(Prompt):", visible=False)
 
231
 
232
 
233
 
234
+ text.submit(request_images, inputs=[text, class_draw, style_draw, sample_size, w, h, seed], outputs=gallery)
235
+ btn.click(request_images, inputs=[text, class_draw, style_draw, sample_size, w, h, seed], outputs=gallery)
 
 
 
 
 
 
 
 
236
 
237
  sample_size.change(
238
  fn=uifn.change_img_choices,
 
246
  elem_id='img2img_prompt_input',
247
  placeholder="神奇的森林,流淌的河流.",
248
  lines=1,
249
+ max_lines=1,
250
+ value="",
251
  show_label=False).style()
252
 
253
  img2img_btn_mask = gr.Button("Generate", variant="primary", visible=False,
 
257
  with gr.Row().style(equal_height=False):
258
  #with gr.Column():
259
  img2img_image_mask = gr.Image(
260
+ value=None,
261
  source="upload",
262
  interactive=True,
263
  tool="sketch",
 
265
  elem_id="img2img_mask",
266
  image_mode="RGBA"
267
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
  gr.Markdown('#### 编辑后的图片')
269
  with gr.Row():
270
  output_img2img_gallery = gr.Gallery(label="Images", elem_id="img2img_gallery_output").style(
271
  grid=[4,4,4] )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
  with gr.Row():
273
  gr.Markdown('提示(prompt):')
274
  with gr.Row():
 
277
  gr.Markdown('Please select an image to cover up a part of the area and enter a text description.')
278
  gr.Markdown('# 编辑设置',visible=False)
279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
 
 
 
 
 
 
 
 
 
 
281
  output_txt2img_copy_to_input_btn.click(
282
  uifn.copy_img_to_input,
283
  [gallery, img_choices],
284
+ [tabs, img2img_image_mask, move_prompt_zh, move_prompt_en, prompt]
285
  )
286
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
 
288
  img2img_func = img2img
289
+ img2img_inputs = [img2img_prompt, img2img_image_mask]
 
 
 
 
 
 
 
290
  img2img_outputs = [output_img2img_gallery]
291
 
 
 
 
 
 
 
 
 
292
  img2img_btn_mask.click(
293
  img2img_func,
294
  img2img_inputs,
 
296
  )
297
 
298
  def img2img_submit_params():
 
 
 
 
 
299
  return (img2img_func,
300
  img2img_inputs,
301
  img2img_outputs)
 
307
  _js=call_JS("clickFirstVisibleButton",
308
  rowId="prompt_row"))
309
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
  gr.HTML(read_content("footer.html"))
311
  # gr.Image('./contributors.png')
312
 
313
+ block.queue(max_size=50, concurrency_count=20).launch()
ui_functions.py CHANGED
@@ -96,16 +96,16 @@ def copy_img_to_input(img, idx):
96
  "图片4(img4)":3,
97
  }
98
  idx = idx_map[idx]
99
- image_data = re.sub('^data:image/.+;base64,', '', img[idx])
100
- processed_image = Image.open(BytesIO(base64.b64decode(image_data)))
101
  tab_update = gr.update(selected='img2img_tab')
102
- img_update = gr.update(value=processed_image)
103
  move_prompt_zh_update = gr.update(visible=True)
104
  move_prompt_en_update = gr.update(visible=True)
105
  prompt_update = gr.update(visible=True)
106
- return tab_update,processed_image, processed_image, move_prompt_zh_update, move_prompt_en_update, prompt_update
107
- except IndexError:
108
- return [None, None]
 
109
 
110
  def copy_img_to_edit(img):
111
  try:
 
96
  "图片4(img4)":3,
97
  }
98
  idx = idx_map[idx]
99
+ assert img[idx]['is_file']
100
+ processed_image = Image.open(img[idx]['name'])
101
  tab_update = gr.update(selected='img2img_tab')
 
102
  move_prompt_zh_update = gr.update(visible=True)
103
  move_prompt_en_update = gr.update(visible=True)
104
  prompt_update = gr.update(visible=True)
105
+ return tab_update, processed_image, move_prompt_zh_update, move_prompt_en_update, prompt_update
106
+ except IndexError as e:
107
+ raise gr.Error(e)
108
+ return [None, None, None, None, None]
109
 
110
  def copy_img_to_edit(img):
111
  try: