chaojiemao commited on
Commit
ec9288d
1 Parent(s): 78e9f55

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +288 -109
app.py CHANGED
@@ -1,45 +1,40 @@
1
  # -*- coding: utf-8 -*-
2
  # Copyright (c) Alibaba, Inc. and its affiliates.
3
- import os
4
- import shlex
5
- import subprocess
6
- subprocess.run(shlex.split('pip install flash-attn --no-build-isolation'), env=os.environ | {'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"})
7
-
8
- import sys
9
- import csv
10
- csv.field_size_limit(sys.maxsize)
11
-
12
- import argparse
13
  import base64
14
  import copy
15
  import glob
16
  import io
17
- import os
18
  import random
19
  import re
 
20
  import string
 
21
  import threading
22
  import spaces
 
 
 
 
23
  import cv2
24
  import gradio as gr
25
  import numpy as np
26
  import torch
27
  import transformers
28
- from diffusers import CogVideoXImageToVideoPipeline
29
- from diffusers.utils import export_to_video
30
- from gradio_imageslider import ImageSlider
31
  from PIL import Image
32
  from transformers import AutoModel, AutoTokenizer
33
 
 
34
  from scepter.modules.utils.config import Config
35
  from scepter.modules.utils.directory import get_md5
36
  from scepter.modules.utils.file_system import FS
37
  from scepter.studio.utils.env import init_env
 
38
 
39
- from infer import ACEInference
40
  from example import get_examples
41
  from utils import load_image
42
 
 
43
 
44
  refresh_sty = '\U0001f504' # 🔄
45
  clear_sty = '\U0001f5d1' # 🗑️
@@ -53,33 +48,56 @@ lock = threading.Lock()
53
 
54
  class ChatBotUI(object):
55
  def __init__(self,
56
- cfg,
 
 
57
  root_work_dir='./'):
58
-
 
 
 
 
 
 
 
 
 
 
 
59
  cfg.WORK_DIR = os.path.join(root_work_dir, cfg.WORK_DIR)
60
  if not FS.exists(cfg.WORK_DIR):
61
  FS.make_dir(cfg.WORK_DIR)
62
  cfg = init_env(cfg)
63
  self.cache_dir = cfg.WORK_DIR
64
- self.chatbot_examples = get_examples(self.cache_dir)
65
  self.model_cfg_dir = cfg.MODEL.EDIT_MODEL.MODEL_CFG_DIR
66
  self.model_yamls = glob.glob(os.path.join(self.model_cfg_dir,
67
  '*.yaml'))
68
  self.model_choices = dict()
 
69
  for i in self.model_yamls:
70
- model_name = '.'.join(i.split('/')[-1].split('.')[:-1])
71
- self.model_choices[model_name] = i
72
- print('Models: ', self.model_choices)
73
-
74
- self.model_name = cfg.MODEL.EDIT_MODEL.DEFAULT
75
- assert self.model_name in self.model_choices
76
- model_cfg = Config(load=True,
77
- cfg_file=self.model_choices[self.model_name])
 
 
 
 
 
 
 
 
78
  self.pipe = ACEInference()
79
- self.pipe.init_from_cfg(model_cfg)
80
  self.max_msgs = 20
81
-
82
  self.enable_i2v = cfg.get('ENABLE_I2V', False)
 
 
83
  if self.enable_i2v:
84
  self.i2v_model_dir = cfg.MODEL.I2V.MODEL_DIR
85
  self.i2v_model_name = cfg.MODEL.I2V.MODEL_NAME
@@ -115,15 +133,11 @@ class ChatBotUI(object):
115
  )
116
 
117
  sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
118
-
119
  For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
120
  There are a few rules to follow:
121
-
122
  You will only ever output a single video description per user request.
123
-
124
  When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
125
  Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
126
-
127
  Video descriptions must have the same num of words as examples below. Extra words will be ignored.
128
  """
129
  self.enhance_ctx = [
@@ -170,6 +184,7 @@ class ChatBotUI(object):
170
  ]
171
 
172
  def create_ui(self):
 
173
  css = '.chatbot.prose.md {opacity: 1.0 !important} #chatbot {opacity: 1.0 !important}'
174
  with gr.Blocks(css=css,
175
  title='Chatbot',
@@ -180,7 +195,8 @@ class ChatBotUI(object):
180
  self.history_result = gr.State(value={})
181
  self.retry_msg = gr.State(value='')
182
  with gr.Group():
183
- with gr.Row(equal_height=True):
 
184
  with gr.Column(visible=True) as self.chat_page:
185
  self.chatbot = gr.Chatbot(
186
  height=600,
@@ -195,7 +211,7 @@ class ChatBotUI(object):
195
  size='sm')
196
 
197
  with gr.Column(visible=False) as self.editor_page:
198
- with gr.Tabs():
199
  with gr.Tab(id='ImageUploader',
200
  label='Image Uploader',
201
  visible=True) as self.upload_tab:
@@ -204,7 +220,7 @@ class ChatBotUI(object):
204
  interactive=True,
205
  type='pil',
206
  image_mode='RGB',
207
- sources='upload',
208
  elem_id='image_uploader',
209
  format='png')
210
  with gr.Row():
@@ -212,10 +228,9 @@ class ChatBotUI(object):
212
  value='Submit',
213
  elem_id='upload_submit')
214
  self.ext_btn_1 = gr.Button(value='Exit')
215
-
216
  with gr.Tab(id='ImageEditor',
217
- label='Image Editor',
218
- visible=False) as self.edit_tab:
219
  self.mask_type = gr.Dropdown(
220
  label='Mask Type',
221
  choices=[
@@ -278,13 +293,23 @@ class ChatBotUI(object):
278
  self.ext_btn_2 = gr.Button(value='Exit')
279
 
280
  with gr.Tab(id='ImageViewer',
281
- label='Image Viewer',
282
- visible=False) as self.image_view_tab:
283
- self.image_viewer = ImageSlider(
284
- label='Image',
285
- type='pil',
286
- show_download_button=True,
287
- elem_id='image_viewer')
 
 
 
 
 
 
 
 
 
 
288
 
289
  self.ext_btn_3 = gr.Button(value='Exit')
290
 
@@ -303,11 +328,30 @@ class ChatBotUI(object):
303
 
304
  self.ext_btn_4 = gr.Button(value='Exit')
305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
  with gr.Accordion(label='Setting', open=False):
307
  with gr.Row():
308
  self.model_name_dd = gr.Dropdown(
309
  choices=self.model_choices,
310
- value=self.model_name,
311
  label='Model Version')
312
 
313
  with gr.Row():
@@ -318,39 +362,63 @@ class ChatBotUI(object):
318
  label='Negative Prompt',
319
  container=False)
320
 
 
 
 
 
 
 
 
 
 
 
 
321
  with gr.Row():
322
  with gr.Column(scale=8, min_width=500):
323
  with gr.Row():
324
  self.step = gr.Slider(minimum=1,
325
  maximum=1000,
326
- value=20,
 
327
  label='Sample Step')
328
  self.cfg_scale = gr.Slider(
329
  minimum=1.0,
330
  maximum=20.0,
331
- value=4.5,
 
332
  label='Guidance Scale')
333
  self.rescale = gr.Slider(minimum=0.0,
334
  maximum=1.0,
335
- value=0.5,
 
336
  label='Rescale')
 
 
 
 
 
337
  self.seed = gr.Slider(minimum=-1,
338
  maximum=10000000,
339
  value=-1,
340
  label='Seed')
341
  self.output_height = gr.Slider(
342
  minimum=256,
343
- maximum=1024,
344
- value=512,
 
345
  label='Output Height')
346
  self.output_width = gr.Slider(
347
  minimum=256,
348
- maximum=1024,
349
- value=512,
 
350
  label='Output Width')
351
  with gr.Column(scale=1, min_width=50):
352
  self.use_history = gr.Checkbox(value=False,
353
  label='Use History')
 
 
 
354
  self.video_auto = gr.Checkbox(
355
  value=False,
356
  label='Auto Gen Video',
@@ -387,9 +455,8 @@ class ChatBotUI(object):
387
  visible=True)
388
 
389
  with gr.Row():
390
- inst = """
391
  **Instruction**:
392
-
393
  1. Click 'Upload' button to upload one or more images as input images.
394
  2. Enter '@' in the text box will exhibit all images in the gallery.
395
  3. Select the image you wish to edit from the gallery, and its Image ID will be displayed in the text box.
@@ -399,14 +466,24 @@ class ChatBotUI(object):
399
  6. **Important** To render text on an image, please ensure to include a space between each letter. For instance, "add text 'g i r l' on the mask area of @xxxxx".
400
  7. To implement local editing based on a specified mask, simply click on the image within the chat window to access the image editor. Here, you can draw a mask and then click the 'Submit' button to upload the edited image along with the mask. For inpainting tasks, select the 'Composite' mask type, while for outpainting tasks, choose the 'Outpainting' mask type. For all other local editing tasks, please select the 'Background' mask type.
401
  8. If you find our work valuable, we invite you to refer to the [ACE Page](https://ali-vilab.github.io/ace-page/) for comprehensive information.
402
-
403
  """
404
- gr.Markdown(value=inst)
405
-
 
 
 
 
 
 
 
 
 
 
 
406
  with gr.Row(variant='panel',
407
  equal_height=True,
408
  show_progress=False):
409
- with gr.Column(scale=1, min_width=100):
410
  self.upload_btn = gr.Button(value=upload_sty +
411
  ' Upload',
412
  variant='secondary')
@@ -416,12 +493,16 @@ class ChatBotUI(object):
416
  label='Instruction',
417
  container=False)
418
  with gr.Column(scale=1, min_width=100):
419
- self.chat_btn = gr.Button(value=chat_sty + ' Chat',
420
  variant='primary')
421
  with gr.Column(scale=1, min_width=100):
422
  self.retry_btn = gr.Button(value=refresh_sty +
423
  ' Retry',
424
  variant='secondary')
 
 
 
 
425
  with gr.Column(scale=(1 if self.enable_i2v else 0),
426
  min_width=0):
427
  self.video_gen_btn = gr.Button(value=video_sty +
@@ -457,19 +538,77 @@ class ChatBotUI(object):
457
  lock.acquire()
458
  del self.pipe
459
  torch.cuda.empty_cache()
460
- model_cfg = Config(load=True,
461
- cfg_file=self.model_choices[model_name])
462
  self.pipe = ACEInference()
463
- self.pipe.init_from_cfg(model_cfg)
464
  self.model_name = model_name
465
  lock.release()
466
 
467
- return model_name, gr.update(), gr.update()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
468
 
469
  self.model_name_dd.change(
470
  change_model,
471
  inputs=[self.model_name_dd],
472
- outputs=[self.model_name_dd, self.chatbot, self.text])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
473
 
474
  ########################################
475
  def generate_gallery(text, images):
@@ -516,7 +655,6 @@ class ChatBotUI(object):
516
  outputs=[self.text, self.gallery])
517
 
518
  ########################################
519
- @spaces.GPU(duration=120)
520
  def generate_video(message,
521
  extend_prompt,
522
  history,
@@ -527,6 +665,9 @@ class ChatBotUI(object):
527
  fps,
528
  seed,
529
  progress=gr.Progress(track_tqdm=True)):
 
 
 
530
  generator = torch.Generator(device='cuda').manual_seed(seed)
531
  img_ids = re.findall('@(.*?)[ ,;.?$]', message)
532
  if len(img_ids) == 0:
@@ -597,8 +738,12 @@ class ChatBotUI(object):
597
  outputs=[self.history, self.chatbot, self.text, self.gallery])
598
 
599
  ########################################
600
- @spaces.GPU(duration=60)
601
- def run_chat(message,
 
 
 
 
602
  extend_prompt,
603
  history,
604
  images,
@@ -607,6 +752,8 @@ class ChatBotUI(object):
607
  negative_prompt,
608
  cfg_scale,
609
  rescale,
 
 
610
  step,
611
  seed,
612
  output_h,
@@ -618,12 +765,25 @@ class ChatBotUI(object):
618
  video_fps,
619
  video_seed,
620
  progress=gr.Progress(track_tqdm=True)):
 
 
 
 
 
 
621
  retry_msg = message
622
  gen_id = get_md5(message)[:12]
623
  save_path = os.path.join(self.cache_dir, f'{gen_id}.png')
624
 
625
  img_ids = re.findall('@(.*?)[ ,;.?$]', message)
626
  history_io = None
 
 
 
 
 
 
 
627
  new_message = message
628
 
629
  if len(img_ids) > 0:
@@ -655,9 +815,9 @@ class ChatBotUI(object):
655
  history_io = history_result[img_id]
656
 
657
  buffered = io.BytesIO()
658
- edit_image[0].save(buffered, format='JPEG')
659
  img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
660
- img_str = f'<img src="data:image/jpg;base64,{img_b64}" style="pointer-events: none;">'
661
  pre_info = f'Received one or more images, so image editing is conducted.\n The first input image @{img_ids[0]} is:\n {img_str}'
662
  else:
663
  pre_info = 'No image ids were found in the provided text prompt, so text-guided image generation is conducted. \n'
@@ -682,6 +842,9 @@ class ChatBotUI(object):
682
  guide_scale=cfg_scale,
683
  guide_rescale=rescale,
684
  seed=seed,
 
 
 
685
  )
686
 
687
  img = imgs[0]
@@ -728,9 +891,9 @@ class ChatBotUI(object):
728
  }
729
 
730
  buffered = io.BytesIO()
731
- img.convert('RGB').save(buffered, format='JPEG')
732
  img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
733
- img_str = f'<img src="data:image/jpg;base64,{img_b64}" style="pointer-events: none;">'
734
 
735
  history.append(
736
  (message,
@@ -790,21 +953,25 @@ class ChatBotUI(object):
790
  while len(history) >= self.max_msgs:
791
  history.pop(0)
792
 
793
- return history, images, history_result, self.get_history(
794
- history), gr.update(value=''), gr.update(
795
- visible=False), retry_msg
 
796
 
797
  chat_inputs = [
 
798
  self.extend_prompt, self.history, self.images, self.use_history,
799
  self.history_result, self.negative_prompt, self.cfg_scale,
800
- self.rescale, self.step, self.seed, self.output_height,
 
801
  self.output_width, self.video_auto, self.video_step,
802
  self.video_frames, self.video_cfg_scale, self.video_fps,
803
  self.video_seed
804
  ]
805
 
806
  chat_outputs = [
807
- self.history, self.images, self.history_result, self.chatbot,
 
808
  self.text, self.gallery, self.retry_msg
809
  ]
810
 
@@ -824,7 +991,7 @@ class ChatBotUI(object):
824
  outputs=chat_outputs)
825
 
826
  ########################################
827
- @spaces.GPU(duration=60)
828
  def run_example(task, img, img_mask, ref1, prompt, seed):
829
  edit_image, edit_image_mask, edit_task = [], [], []
830
  if img is not None:
@@ -848,9 +1015,9 @@ class ChatBotUI(object):
848
  edit_task.append('')
849
 
850
  buffered = io.BytesIO()
851
- img.save(buffered, format='JPEG')
852
  img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
853
- img_str = f'<img src="data:image/jpg;base64,{img_b64}" style="pointer-events: none;">'
854
  pre_info = f'Received one or more images, so image editing is conducted.\n The first input image is:\n {img_str}'
855
  else:
856
  pre_info = 'No image ids were found in the provided text prompt, so text-guided image generation is conducted. \n'
@@ -866,13 +1033,15 @@ class ChatBotUI(object):
866
  prompt=[prompt] * img_num,
867
  negative_prompt=[''] * img_num,
868
  seed=seed,
 
 
869
  )
870
 
871
  img = imgs[0]
872
  buffered = io.BytesIO()
873
- img.convert('RGB').save(buffered, format='JPEG')
874
  img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
875
- img_str = f'<img src="data:image/jpg;base64,{img_b64}" style="pointer-events: none;">'
876
  history = [(prompt,
877
  f'{pre_info} The generated image is:\n {img_str}')]
878
  return self.get_history(history), gr.update(value=''), gr.update(
@@ -911,21 +1080,23 @@ class ChatBotUI(object):
911
  return (gr.update(visible=True,
912
  scale=1), gr.update(visible=True, scale=1),
913
  gr.update(visible=True), gr.update(visible=False),
914
- gr.update(visible=False), gr.update(visible=False))
 
915
 
916
  self.upload_btn.click(upload_image,
917
  inputs=[],
918
  outputs=[
919
  self.chat_page, self.editor_page,
920
  self.upload_tab, self.edit_tab,
921
- self.image_view_tab, self.video_view_tab
 
922
  ])
923
 
924
  ########################################
925
  def edit_image(evt: gr.SelectData):
926
  if isinstance(evt.value, str):
927
  img_b64s = re.findall(
928
- '<img src="data:image/jpg;base64,(.*?)" style="pointer-events: none;">',
929
  evt.value)
930
  imgs = [
931
  Image.open(io.BytesIO(base64.b64decode(copy.deepcopy(i))))
@@ -933,13 +1104,19 @@ class ChatBotUI(object):
933
  ]
934
  if len(imgs) > 0:
935
  if len(imgs) == 2:
936
- view_img = copy.deepcopy(imgs)
 
 
 
937
  edit_img = copy.deepcopy(imgs[-1])
938
  else:
939
- view_img = [
940
- copy.deepcopy(imgs[-1]),
941
- copy.deepcopy(imgs[-1])
942
- ]
 
 
 
943
  edit_img = copy.deepcopy(imgs[-1])
944
 
945
  return (gr.update(visible=True,
@@ -948,11 +1125,12 @@ class ChatBotUI(object):
948
  gr.update(visible=False), gr.update(visible=True),
949
  gr.update(visible=True), gr.update(visible=False),
950
  gr.update(value=edit_img),
951
- gr.update(value=view_img), gr.update(value=None))
 
952
  else:
953
  return (gr.update(), gr.update(), gr.update(), gr.update(),
954
  gr.update(), gr.update(), gr.update(), gr.update(),
955
- gr.update())
956
  elif isinstance(evt.value, dict) and evt.value.get(
957
  'component', '') == 'video':
958
  value = evt.value['value']['video']['path']
@@ -960,11 +1138,12 @@ class ChatBotUI(object):
960
  scale=1), gr.update(visible=True, scale=1),
961
  gr.update(visible=False), gr.update(visible=False),
962
  gr.update(visible=False), gr.update(visible=True),
963
- gr.update(), gr.update(), gr.update(value=value))
 
964
  else:
965
  return (gr.update(), gr.update(), gr.update(), gr.update(),
966
  gr.update(), gr.update(), gr.update(), gr.update(),
967
- gr.update())
968
 
969
  self.chatbot.select(edit_image,
970
  outputs=[
@@ -972,16 +1151,17 @@ class ChatBotUI(object):
972
  self.upload_tab, self.edit_tab,
973
  self.image_view_tab, self.video_view_tab,
974
  self.image_editor, self.image_viewer,
975
- self.video_viewer
976
  ])
977
 
978
- self.image_viewer.change(lambda x: x,
979
- inputs=self.image_viewer,
980
- outputs=self.image_viewer)
 
981
 
982
  ########################################
983
  def submit_upload_image(image, history, images):
984
- history, images = self.add_uploaded_image_to_history(
985
  image, history, images)
986
  return gr.update(visible=False), gr.update(
987
  visible=True), gr.update(
@@ -1151,14 +1331,14 @@ class ChatBotUI(object):
1151
  thumbnail.save(thumbnail_path, format='JPEG')
1152
 
1153
  buffered = io.BytesIO()
1154
- img.convert('RGB').save(buffered, format='JPEG')
1155
  img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
1156
- img_str = f'<img src="data:image/jpg;base64,{img_b64}" style="pointer-events: none;">'
1157
 
1158
  buffered = io.BytesIO()
1159
- mask.convert('RGB').save(buffered, format='JPEG')
1160
  mask_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
1161
- mask_str = f'<img src="data:image/jpg;base64,{mask_b64}" style="pointer-events: none;">'
1162
 
1163
  images[img_id] = {
1164
  'image': save_path,
@@ -1207,22 +1387,21 @@ class ChatBotUI(object):
1207
  }
1208
 
1209
  buffered = io.BytesIO()
1210
- img.convert('RGB').save(buffered, format='JPEG')
1211
  img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
1212
- img_str = f'<img src="data:image/jpg;base64,{img_b64}" style="pointer-events: none;">'
1213
 
1214
  history.append(
1215
  (None,
1216
  f'This is uploaded image:\n {img_str} image ID is: {img_id}'))
1217
- return history, images
1218
 
1219
 
1220
 
1221
  if __name__ == '__main__':
1222
-
1223
- cfg = Config(cfg_file="config/chatbot_ui.yaml")
1224
  with gr.Blocks() as demo:
1225
  chatbot = ChatBotUI(cfg)
1226
  chatbot.create_ui()
1227
  chatbot.set_callbacks()
1228
- demo.launch()
 
1
  # -*- coding: utf-8 -*-
2
  # Copyright (c) Alibaba, Inc. and its affiliates.
 
 
 
 
 
 
 
 
 
 
3
  import base64
4
  import copy
5
  import glob
6
  import io
7
+ import os, csv, sys
8
  import random
9
  import re
10
+ import shlex
11
  import string
12
+ import subprocess
13
  import threading
14
  import spaces
15
+
16
+ subprocess.run(shlex.split('pip install flash-attn --no-build-isolation'),
17
+ env=os.environ | {'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"})
18
+
19
  import cv2
20
  import gradio as gr
21
  import numpy as np
22
  import torch
23
  import transformers
 
 
 
24
  from PIL import Image
25
  from transformers import AutoModel, AutoTokenizer
26
 
27
+ from scepter.modules.inference.ace_inference import ACEInference
28
  from scepter.modules.utils.config import Config
29
  from scepter.modules.utils.directory import get_md5
30
  from scepter.modules.utils.file_system import FS
31
  from scepter.studio.utils.env import init_env
32
+ from importlib.metadata import version
33
 
 
34
  from example import get_examples
35
  from utils import load_image
36
 
37
+ csv.field_size_limit(sys.maxsize)
38
 
39
  refresh_sty = '\U0001f504' # 🔄
40
  clear_sty = '\U0001f5d1' # 🗑️
 
48
 
49
  class ChatBotUI(object):
50
  def __init__(self,
51
+ cfg_general_file,
52
+ is_debug=False,
53
+ language='en',
54
  root_work_dir='./'):
55
+ try:
56
+ from diffusers import CogVideoXImageToVideoPipeline
57
+ from diffusers.utils import export_to_video
58
+ except Exception as e:
59
+ print(f"Import diffusers failed, please install or upgrade diffusers. Error information: {e}")
60
+
61
+ cfg = Config(cfg_file=cfg_general_file)
62
+ if cfg.have("FILE_SYSTEM"):
63
+ for file_sys in cfg.FILE_SYSTEM:
64
+ fs_prefix = FS.init_fs_client(file_sys)
65
+ else:
66
+ fs_prefix = FS.init_fs_client(cfg)
67
  cfg.WORK_DIR = os.path.join(root_work_dir, cfg.WORK_DIR)
68
  if not FS.exists(cfg.WORK_DIR):
69
  FS.make_dir(cfg.WORK_DIR)
70
  cfg = init_env(cfg)
71
  self.cache_dir = cfg.WORK_DIR
72
+ self.chatbot_examples = get_examples(self.cache_dir) if not cfg.get('SKIP_EXAMPLES', False) else []
73
  self.model_cfg_dir = cfg.MODEL.EDIT_MODEL.MODEL_CFG_DIR
74
  self.model_yamls = glob.glob(os.path.join(self.model_cfg_dir,
75
  '*.yaml'))
76
  self.model_choices = dict()
77
+ self.default_model_name = ''
78
  for i in self.model_yamls:
79
+ model_cfg = Config(load=True, cfg_file=i)
80
+ model_name = model_cfg.NAME
81
+ if model_cfg.IS_DEFAULT: self.default_model_name = model_name
82
+ self.model_choices[model_name] = model_cfg
83
+ print('Models: ', self.model_choices.keys())
84
+
85
+ #FS.get_from("ms://AI-ModelScope/FLUX.1-dev@flux1-dev.safetensors")
86
+ #FS.get_from("ms://AI-ModelScope/FLUX.1-dev@ae.safetensors")
87
+ #FS.get_dir_to_local_dir("ms://AI-ModelScope/FLUX.1-dev@text_encoder_2/")
88
+ #FS.get_dir_to_local_dir("ms://AI-ModelScope/FLUX.1-dev@tokenizer_2/")
89
+ #FS.get_dir_to_local_dir("ms://AI-ModelScope/FLUX.1-dev@text_encoder/")
90
+ #FS.get_dir_to_local_dir("ms://AI-ModelScope/FLUX.1-dev@tokenizer/")
91
+
92
+ assert len(self.model_choices) > 0
93
+ if self.default_model_name == "": self.default_model_name = self.model_choices.keys()[0]
94
+ self.model_name = self.default_model_name
95
  self.pipe = ACEInference()
96
+ self.pipe.init_from_cfg(self.model_choices[self.default_model_name])
97
  self.max_msgs = 20
 
98
  self.enable_i2v = cfg.get('ENABLE_I2V', False)
99
+ self.gradio_version = version('gradio')
100
+
101
  if self.enable_i2v:
102
  self.i2v_model_dir = cfg.MODEL.I2V.MODEL_DIR
103
  self.i2v_model_name = cfg.MODEL.I2V.MODEL_NAME
 
133
  )
134
 
135
  sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
 
136
  For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
137
  There are a few rules to follow:
 
138
  You will only ever output a single video description per user request.
 
139
  When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
140
  Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
 
141
  Video descriptions must have the same num of words as examples below. Extra words will be ignored.
142
  """
143
  self.enhance_ctx = [
 
184
  ]
185
 
186
  def create_ui(self):
187
+
188
  css = '.chatbot.prose.md {opacity: 1.0 !important} #chatbot {opacity: 1.0 !important}'
189
  with gr.Blocks(css=css,
190
  title='Chatbot',
 
195
  self.history_result = gr.State(value={})
196
  self.retry_msg = gr.State(value='')
197
  with gr.Group():
198
+ self.ui_mode = gr.State(value='legacy')
199
+ with gr.Row(equal_height=True, visible=False) as self.chat_group:
200
  with gr.Column(visible=True) as self.chat_page:
201
  self.chatbot = gr.Chatbot(
202
  height=600,
 
211
  size='sm')
212
 
213
  with gr.Column(visible=False) as self.editor_page:
214
+ with gr.Tabs(visible=False) as self.upload_tabs:
215
  with gr.Tab(id='ImageUploader',
216
  label='Image Uploader',
217
  visible=True) as self.upload_tab:
 
220
  interactive=True,
221
  type='pil',
222
  image_mode='RGB',
223
+ sources=['upload'],
224
  elem_id='image_uploader',
225
  format='png')
226
  with gr.Row():
 
228
  value='Submit',
229
  elem_id='upload_submit')
230
  self.ext_btn_1 = gr.Button(value='Exit')
231
+ with gr.Tabs(visible=False) as self.edit_tabs:
232
  with gr.Tab(id='ImageEditor',
233
+ label='Image Editor') as self.edit_tab:
 
234
  self.mask_type = gr.Dropdown(
235
  label='Mask Type',
236
  choices=[
 
293
  self.ext_btn_2 = gr.Button(value='Exit')
294
 
295
  with gr.Tab(id='ImageViewer',
296
+ label='Image Viewer') as self.image_view_tab:
297
+ if self.gradio_version >= '5.0.0':
298
+ self.image_viewer = gr.Image(
299
+ label='Image',
300
+ type='pil',
301
+ show_download_button=True,
302
+ elem_id='image_viewer')
303
+ else:
304
+ try:
305
+ from gradio_imageslider import ImageSlider
306
+ except Exception as e:
307
+ print(f"Import gradio_imageslider failed, please install.")
308
+ self.image_viewer = ImageSlider(
309
+ label='Image',
310
+ type='pil',
311
+ show_download_button=True,
312
+ elem_id='image_viewer')
313
 
314
  self.ext_btn_3 = gr.Button(value='Exit')
315
 
 
328
 
329
  self.ext_btn_4 = gr.Button(value='Exit')
330
 
331
+ with gr.Row(equal_height=True, visible=True) as self.legacy_group:
332
+ with gr.Column():
333
+ self.legacy_image_uploader = gr.Image(
334
+ height=550,
335
+ interactive=True,
336
+ type='pil',
337
+ image_mode='RGB',
338
+ elem_id='legacy_image_uploader',
339
+ format='png')
340
+ with gr.Column():
341
+ self.legacy_image_viewer = gr.Image(
342
+ label='Image',
343
+ height=550,
344
+ type='pil',
345
+ interactive=False,
346
+ show_download_button=True,
347
+ elem_id='image_viewer')
348
+
349
+
350
  with gr.Accordion(label='Setting', open=False):
351
  with gr.Row():
352
  self.model_name_dd = gr.Dropdown(
353
  choices=self.model_choices,
354
+ value=self.default_model_name,
355
  label='Model Version')
356
 
357
  with gr.Row():
 
362
  label='Negative Prompt',
363
  container=False)
364
 
365
+ with gr.Row():
366
+ # REFINER_PROMPT
367
+ self.refiner_prompt = gr.Textbox(
368
+ value=self.pipe.input.get("refiner_prompt", ""),
369
+ visible=self.pipe.input.get("refiner_prompt", None) is not None,
370
+ placeholder=
371
+ 'Prompt used for refiner',
372
+ label='Refiner Prompt',
373
+ container=False)
374
+
375
+
376
  with gr.Row():
377
  with gr.Column(scale=8, min_width=500):
378
  with gr.Row():
379
  self.step = gr.Slider(minimum=1,
380
  maximum=1000,
381
+ value=self.pipe.input.get("sample_steps", 20),
382
+ visible=self.pipe.input.get("sample_steps", None) is not None,
383
  label='Sample Step')
384
  self.cfg_scale = gr.Slider(
385
  minimum=1.0,
386
  maximum=20.0,
387
+ value=self.pipe.input.get("guide_scale", 4.5),
388
+ visible=self.pipe.input.get("guide_scale", None) is not None,
389
  label='Guidance Scale')
390
  self.rescale = gr.Slider(minimum=0.0,
391
  maximum=1.0,
392
+ value=self.pipe.input.get("guide_rescale", 0.5),
393
+ visible=self.pipe.input.get("guide_rescale", None) is not None,
394
  label='Rescale')
395
+ self.refiner_scale = gr.Slider(minimum=-0.1,
396
+ maximum=1.0,
397
+ value=self.pipe.input.get("refiner_scale", 0.5),
398
+ visible=self.pipe.input.get("refiner_scale", None) is not None,
399
+ label='Refiner Scale')
400
  self.seed = gr.Slider(minimum=-1,
401
  maximum=10000000,
402
  value=-1,
403
  label='Seed')
404
  self.output_height = gr.Slider(
405
  minimum=256,
406
+ maximum=1440,
407
+ value=self.pipe.input.get("output_height", 1024),
408
+ visible=self.pipe.input.get("output_height", None) is not None,
409
  label='Output Height')
410
  self.output_width = gr.Slider(
411
  minimum=256,
412
+ maximum=1440,
413
+ value=self.pipe.input.get("output_width", 1024),
414
+ visible=self.pipe.input.get("output_width", None) is not None,
415
  label='Output Width')
416
  with gr.Column(scale=1, min_width=50):
417
  self.use_history = gr.Checkbox(value=False,
418
  label='Use History')
419
+ self.use_ace = gr.Checkbox(value=self.pipe.input.get("use_ace", True),
420
+ visible=self.pipe.input.get("use_ace", None) is not None,
421
+ label='Use ACE')
422
  self.video_auto = gr.Checkbox(
423
  value=False,
424
  label='Auto Gen Video',
 
455
  visible=True)
456
 
457
  with gr.Row():
458
+ self.chatbot_inst = """
459
  **Instruction**:
 
460
  1. Click 'Upload' button to upload one or more images as input images.
461
  2. Enter '@' in the text box will exhibit all images in the gallery.
462
  3. Select the image you wish to edit from the gallery, and its Image ID will be displayed in the text box.
 
466
  6. **Important** To render text on an image, please ensure to include a space between each letter. For instance, "add text 'g i r l' on the mask area of @xxxxx".
467
  7. To implement local editing based on a specified mask, simply click on the image within the chat window to access the image editor. Here, you can draw a mask and then click the 'Submit' button to upload the edited image along with the mask. For inpainting tasks, select the 'Composite' mask type, while for outpainting tasks, choose the 'Outpainting' mask type. For all other local editing tasks, please select the 'Background' mask type.
468
  8. If you find our work valuable, we invite you to refer to the [ACE Page](https://ali-vilab.github.io/ace-page/) for comprehensive information.
 
469
  """
470
+
471
+ self.legacy_inst = """
472
+ **Instruction**:
473
+ 1. You can edit the image by uploading it; if no image is uploaded, an image will be generated from text..
474
+ 2. Enter '@' in the text box will exhibit all images in the gallery.
475
+ 3. Select the image you wish to edit from the gallery, and its Image ID will be displayed in the text box.
476
+ 4. **Important** To render text on an image, please ensure to include a space between each letter. For instance, "add text 'g i r l' on the mask area of @xxxxx".
477
+ 5. To perform multi-step editing, partial editing, inpainting, outpainting, and other operations, please click the Chatbot Checkbox to enable the conversational editing mode and follow the relevant instructions..
478
+ 6. If you find our work valuable, we invite you to refer to the [ACE Page](https://ali-vilab.github.io/ace-page/) for comprehensive information.
479
+ """
480
+
481
+ self.instruction = gr.Markdown(value=self.legacy_inst)
482
+
483
  with gr.Row(variant='panel',
484
  equal_height=True,
485
  show_progress=False):
486
+ with gr.Column(scale=1, min_width=100, visible=False) as self.upload_panel:
487
  self.upload_btn = gr.Button(value=upload_sty +
488
  ' Upload',
489
  variant='secondary')
 
493
  label='Instruction',
494
  container=False)
495
  with gr.Column(scale=1, min_width=100):
496
+ self.chat_btn = gr.Button(value='Generate',
497
  variant='primary')
498
  with gr.Column(scale=1, min_width=100):
499
  self.retry_btn = gr.Button(value=refresh_sty +
500
  ' Retry',
501
  variant='secondary')
502
+ with gr.Column(scale=1, min_width=100):
503
+ self.mode_checkbox = gr.Checkbox(
504
+ value=False,
505
+ label='ChatBot')
506
  with gr.Column(scale=(1 if self.enable_i2v else 0),
507
  min_width=0):
508
  self.video_gen_btn = gr.Button(value=video_sty +
 
538
  lock.acquire()
539
  del self.pipe
540
  torch.cuda.empty_cache()
 
 
541
  self.pipe = ACEInference()
542
+ self.pipe.init_from_cfg(self.model_choices[model_name])
543
  self.model_name = model_name
544
  lock.release()
545
 
546
+ return (model_name, gr.update(), gr.update(),
547
+ gr.Slider(
548
+ value=self.pipe.input.get("sample_steps", 20),
549
+ visible=self.pipe.input.get("sample_steps", None) is not None),
550
+ gr.Slider(
551
+ value=self.pipe.input.get("guide_scale", 4.5),
552
+ visible=self.pipe.input.get("guide_scale", None) is not None),
553
+ gr.Slider(
554
+ value=self.pipe.input.get("guide_rescale", 0.5),
555
+ visible=self.pipe.input.get("guide_rescale", None) is not None),
556
+ gr.Slider(
557
+ value=self.pipe.input.get("output_height", 1024),
558
+ visible=self.pipe.input.get("output_height", None) is not None),
559
+ gr.Slider(
560
+ value=self.pipe.input.get("output_width", 1024),
561
+ visible=self.pipe.input.get("output_width", None) is not None),
562
+ gr.Textbox(
563
+ value=self.pipe.input.get("refiner_prompt", ""),
564
+ visible=self.pipe.input.get("refiner_prompt", None) is not None),
565
+ gr.Slider(
566
+ value=self.pipe.input.get("refiner_scale", 0.5),
567
+ visible=self.pipe.input.get("refiner_scale", None) is not None
568
+ ),
569
+ gr.Checkbox(
570
+ value=self.pipe.input.get("use_ace", True),
571
+ visible=self.pipe.input.get("use_ace", None) is not None
572
+ )
573
+ )
574
 
575
  self.model_name_dd.change(
576
  change_model,
577
  inputs=[self.model_name_dd],
578
+ outputs=[
579
+ self.model_name_dd, self.chatbot, self.text,
580
+ self.step,
581
+ self.cfg_scale, self.rescale, self.output_height,
582
+ self.output_width, self.refiner_prompt, self.refiner_scale,
583
+ self.use_ace])
584
+
585
+
586
+ def mode_change(mode_check):
587
+ if mode_check:
588
+ # ChatBot
589
+ return (
590
+ gr.Row(visible=False),
591
+ gr.Row(visible=True),
592
+ gr.Button(value='Generate'),
593
+ gr.State(value='chatbot'),
594
+ gr.Column(visible=True),
595
+ gr.Markdown(value=self.chatbot_inst)
596
+ )
597
+ else:
598
+ # Legacy
599
+ return (
600
+ gr.Row(visible=True),
601
+ gr.Row(visible=False),
602
+ gr.Button(value=chat_sty + ' Chat'),
603
+ gr.State(value='legacy'),
604
+ gr.Column(visible=False),
605
+ gr.Markdown(value=self.legacy_inst)
606
+ )
607
+ self.mode_checkbox.change(mode_change, inputs=[self.mode_checkbox],
608
+ outputs=[self.legacy_group, self.chat_group,
609
+ self.chat_btn, self.ui_mode,
610
+ self.upload_panel, self.instruction])
611
+
612
 
613
  ########################################
614
  def generate_gallery(text, images):
 
655
  outputs=[self.text, self.gallery])
656
 
657
  ########################################
 
658
  def generate_video(message,
659
  extend_prompt,
660
  history,
 
665
  fps,
666
  seed,
667
  progress=gr.Progress(track_tqdm=True)):
668
+
669
+ from diffusers.utils import export_to_video
670
+
671
  generator = torch.Generator(device='cuda').manual_seed(seed)
672
  img_ids = re.findall('@(.*?)[ ,;.?$]', message)
673
  if len(img_ids) == 0:
 
738
  outputs=[self.history, self.chatbot, self.text, self.gallery])
739
 
740
  ########################################
741
+ @spaces.GPU(duration=240)
742
+ def run_chat(
743
+ message,
744
+ legacy_image,
745
+ ui_mode,
746
+ use_ace,
747
  extend_prompt,
748
  history,
749
  images,
 
752
  negative_prompt,
753
  cfg_scale,
754
  rescale,
755
+ refiner_prompt,
756
+ refiner_scale,
757
  step,
758
  seed,
759
  output_h,
 
765
  video_fps,
766
  video_seed,
767
  progress=gr.Progress(track_tqdm=True)):
768
+ legacy_img_ids = []
769
+ if ui_mode == 'legacy':
770
+ if legacy_image is not None:
771
+ history, images, img_id = self.add_uploaded_image_to_history(
772
+ legacy_image, history, images)
773
+ legacy_img_ids.append(img_id)
774
  retry_msg = message
775
  gen_id = get_md5(message)[:12]
776
  save_path = os.path.join(self.cache_dir, f'{gen_id}.png')
777
 
778
  img_ids = re.findall('@(.*?)[ ,;.?$]', message)
779
  history_io = None
780
+
781
+ if len(img_ids) < 1:
782
+ img_ids = legacy_img_ids
783
+ for img_id in img_ids:
784
+ if f'@{img_id}' not in message:
785
+ message = f'@{img_id} ' + message
786
+
787
  new_message = message
788
 
789
  if len(img_ids) > 0:
 
815
  history_io = history_result[img_id]
816
 
817
  buffered = io.BytesIO()
818
+ edit_image[0].save(buffered, format='PNG')
819
  img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
820
+ img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
821
  pre_info = f'Received one or more images, so image editing is conducted.\n The first input image @{img_ids[0]} is:\n {img_str}'
822
  else:
823
  pre_info = 'No image ids were found in the provided text prompt, so text-guided image generation is conducted. \n'
 
842
  guide_scale=cfg_scale,
843
  guide_rescale=rescale,
844
  seed=seed,
845
+ refiner_prompt=refiner_prompt,
846
+ refiner_scale=refiner_scale,
847
+ use_ace=use_ace
848
  )
849
 
850
  img = imgs[0]
 
891
  }
892
 
893
  buffered = io.BytesIO()
894
+ img.convert('RGB').save(buffered, format='PNG')
895
  img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
896
+ img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
897
 
898
  history.append(
899
  (message,
 
953
  while len(history) >= self.max_msgs:
954
  history.pop(0)
955
 
956
+ return (history, images, gr.Image(value=save_path),
957
+ history_result, self.get_history(
958
+ history), gr.update(), gr.update(
959
+ visible=False), retry_msg)
960
 
961
  chat_inputs = [
962
+ self.legacy_image_uploader, self.ui_mode, self.use_ace,
963
  self.extend_prompt, self.history, self.images, self.use_history,
964
  self.history_result, self.negative_prompt, self.cfg_scale,
965
+ self.rescale, self.refiner_prompt, self.refiner_scale,
966
+ self.step, self.seed, self.output_height,
967
  self.output_width, self.video_auto, self.video_step,
968
  self.video_frames, self.video_cfg_scale, self.video_fps,
969
  self.video_seed
970
  ]
971
 
972
  chat_outputs = [
973
+ self.history, self.images, self.legacy_image_viewer,
974
+ self.history_result, self.chatbot,
975
  self.text, self.gallery, self.retry_msg
976
  ]
977
 
 
991
  outputs=chat_outputs)
992
 
993
  ########################################
994
+ @spaces.GPU(duration=120)
995
  def run_example(task, img, img_mask, ref1, prompt, seed):
996
  edit_image, edit_image_mask, edit_task = [], [], []
997
  if img is not None:
 
1015
  edit_task.append('')
1016
 
1017
  buffered = io.BytesIO()
1018
+ img.save(buffered, format='PNG')
1019
  img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
1020
+ img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
1021
  pre_info = f'Received one or more images, so image editing is conducted.\n The first input image is:\n {img_str}'
1022
  else:
1023
  pre_info = 'No image ids were found in the provided text prompt, so text-guided image generation is conducted. \n'
 
1033
  prompt=[prompt] * img_num,
1034
  negative_prompt=[''] * img_num,
1035
  seed=seed,
1036
+ refiner_prompt=self.pipe.input.get("refiner_prompt", ""),
1037
+ refiner_scale=self.pipe.input.get("refiner_scale", 0.0),
1038
  )
1039
 
1040
  img = imgs[0]
1041
  buffered = io.BytesIO()
1042
+ img.convert('RGB').save(buffered, format='PNG')
1043
  img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
1044
+ img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
1045
  history = [(prompt,
1046
  f'{pre_info} The generated image is:\n {img_str}')]
1047
  return self.get_history(history), gr.update(value=''), gr.update(
 
1080
  return (gr.update(visible=True,
1081
  scale=1), gr.update(visible=True, scale=1),
1082
  gr.update(visible=True), gr.update(visible=False),
1083
+ gr.update(visible=False), gr.update(visible=False),
1084
+ gr.update(visible=True))
1085
 
1086
  self.upload_btn.click(upload_image,
1087
  inputs=[],
1088
  outputs=[
1089
  self.chat_page, self.editor_page,
1090
  self.upload_tab, self.edit_tab,
1091
+ self.image_view_tab, self.video_view_tab,
1092
+ self.upload_tabs
1093
  ])
1094
 
1095
  ########################################
1096
  def edit_image(evt: gr.SelectData):
1097
  if isinstance(evt.value, str):
1098
  img_b64s = re.findall(
1099
+ '<img src="data:image/png;base64,(.*?)" style="pointer-events: none;">',
1100
  evt.value)
1101
  imgs = [
1102
  Image.open(io.BytesIO(base64.b64decode(copy.deepcopy(i))))
 
1104
  ]
1105
  if len(imgs) > 0:
1106
  if len(imgs) == 2:
1107
+ if self.gradio_version >= '5.0.0':
1108
+ view_img = copy.deepcopy(imgs[-1])
1109
+ else:
1110
+ view_img = copy.deepcopy(imgs)
1111
  edit_img = copy.deepcopy(imgs[-1])
1112
  else:
1113
+ if self.gradio_version >= '5.0.0':
1114
+ view_img = copy.deepcopy(imgs[-1])
1115
+ else:
1116
+ view_img = [
1117
+ copy.deepcopy(imgs[-1]),
1118
+ copy.deepcopy(imgs[-1])
1119
+ ]
1120
  edit_img = copy.deepcopy(imgs[-1])
1121
 
1122
  return (gr.update(visible=True,
 
1125
  gr.update(visible=False), gr.update(visible=True),
1126
  gr.update(visible=True), gr.update(visible=False),
1127
  gr.update(value=edit_img),
1128
+ gr.update(value=view_img), gr.update(value=None),
1129
+ gr.update(visible=True))
1130
  else:
1131
  return (gr.update(), gr.update(), gr.update(), gr.update(),
1132
  gr.update(), gr.update(), gr.update(), gr.update(),
1133
+ gr.update(), gr.update())
1134
  elif isinstance(evt.value, dict) and evt.value.get(
1135
  'component', '') == 'video':
1136
  value = evt.value['value']['video']['path']
 
1138
  scale=1), gr.update(visible=True, scale=1),
1139
  gr.update(visible=False), gr.update(visible=False),
1140
  gr.update(visible=False), gr.update(visible=True),
1141
+ gr.update(), gr.update(), gr.update(value=value),
1142
+ gr.update())
1143
  else:
1144
  return (gr.update(), gr.update(), gr.update(), gr.update(),
1145
  gr.update(), gr.update(), gr.update(), gr.update(),
1146
+ gr.update(), gr.update())
1147
 
1148
  self.chatbot.select(edit_image,
1149
  outputs=[
 
1151
  self.upload_tab, self.edit_tab,
1152
  self.image_view_tab, self.video_view_tab,
1153
  self.image_editor, self.image_viewer,
1154
+ self.video_viewer, self.edit_tabs
1155
  ])
1156
 
1157
+ if self.gradio_version < '5.0.0':
1158
+ self.image_viewer.change(lambda x: x,
1159
+ inputs=self.image_viewer,
1160
+ outputs=self.image_viewer)
1161
 
1162
  ########################################
1163
  def submit_upload_image(image, history, images):
1164
+ history, images, _ = self.add_uploaded_image_to_history(
1165
  image, history, images)
1166
  return gr.update(visible=False), gr.update(
1167
  visible=True), gr.update(
 
1331
  thumbnail.save(thumbnail_path, format='JPEG')
1332
 
1333
  buffered = io.BytesIO()
1334
+ img.convert('RGB').save(buffered, format='PNG')
1335
  img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
1336
+ img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
1337
 
1338
  buffered = io.BytesIO()
1339
+ mask.convert('RGB').save(buffered, format='PNG')
1340
  mask_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
1341
+ mask_str = f'<img src="data:image/png;base64,{mask_b64}" style="pointer-events: none;">'
1342
 
1343
  images[img_id] = {
1344
  'image': save_path,
 
1387
  }
1388
 
1389
  buffered = io.BytesIO()
1390
+ img.convert('RGB').save(buffered, format='PNG')
1391
  img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
1392
+ img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
1393
 
1394
  history.append(
1395
  (None,
1396
  f'This is uploaded image:\n {img_str} image ID is: {img_id}'))
1397
+ return history, images, img_id
1398
 
1399
 
1400
 
1401
  if __name__ == '__main__':
1402
+ cfg = "config/chatbot_ui.yaml"
 
1403
  with gr.Blocks() as demo:
1404
  chatbot = ChatBotUI(cfg)
1405
  chatbot.create_ui()
1406
  chatbot.set_callbacks()
1407
+ demo.launch()