pan-yl committed on
Commit
a0edd51
·
1 Parent(s): df46181

modify app.py

Browse files
Files changed (2) hide show
  1. app.py +276 -102
  2. utils.py +95 -0
app.py CHANGED
@@ -1,45 +1,40 @@
1
  # -*- coding: utf-8 -*-
2
  # Copyright (c) Alibaba, Inc. and its affiliates.
3
- import os
4
- import shlex
5
- import subprocess
6
- subprocess.run(shlex.split('pip install flash-attn --no-build-isolation'), env=os.environ | {'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"})
7
-
8
- import sys
9
- import csv
10
- csv.field_size_limit(sys.maxsize)
11
-
12
- import argparse
13
  import base64
14
  import copy
15
  import glob
16
  import io
17
- import os
18
  import random
19
  import re
 
20
  import string
 
21
  import threading
22
  import spaces
 
 
 
 
23
  import cv2
24
  import gradio as gr
25
  import numpy as np
26
  import torch
27
  import transformers
28
- from diffusers import CogVideoXImageToVideoPipeline
29
- from diffusers.utils import export_to_video
30
- from gradio_imageslider import ImageSlider
31
  from PIL import Image
32
  from transformers import AutoModel, AutoTokenizer
33
 
 
34
  from scepter.modules.utils.config import Config
35
  from scepter.modules.utils.directory import get_md5
36
  from scepter.modules.utils.file_system import FS
37
  from scepter.studio.utils.env import init_env
 
38
 
39
- from infer import ACEInference
40
- from example import get_examples
41
- from utils import load_image
42
 
 
43
 
44
  refresh_sty = '\U0001f504' # 🔄
45
  clear_sty = '\U0001f5d1' # 🗑️
@@ -53,33 +48,43 @@ lock = threading.Lock()
53
 
54
  class ChatBotUI(object):
55
  def __init__(self,
56
- cfg,
 
 
57
  root_work_dir='./'):
 
 
 
 
 
58
 
 
59
  cfg.WORK_DIR = os.path.join(root_work_dir, cfg.WORK_DIR)
60
  if not FS.exists(cfg.WORK_DIR):
61
  FS.make_dir(cfg.WORK_DIR)
62
  cfg = init_env(cfg)
63
  self.cache_dir = cfg.WORK_DIR
64
- self.chatbot_examples = get_examples(self.cache_dir)
65
  self.model_cfg_dir = cfg.MODEL.EDIT_MODEL.MODEL_CFG_DIR
66
  self.model_yamls = glob.glob(os.path.join(self.model_cfg_dir,
67
  '*.yaml'))
68
  self.model_choices = dict()
 
69
  for i in self.model_yamls:
70
- model_name = '.'.join(i.split('/')[-1].split('.')[:-1])
71
- self.model_choices[model_name] = i
72
- print('Models: ', self.model_choices)
73
-
74
- self.model_name = cfg.MODEL.EDIT_MODEL.DEFAULT
75
- assert self.model_name in self.model_choices
76
- model_cfg = Config(load=True,
77
- cfg_file=self.model_choices[self.model_name])
78
  self.pipe = ACEInference()
79
- self.pipe.init_from_cfg(model_cfg)
80
  self.max_msgs = 20
81
-
82
  self.enable_i2v = cfg.get('ENABLE_I2V', False)
 
 
83
  if self.enable_i2v:
84
  self.i2v_model_dir = cfg.MODEL.I2V.MODEL_DIR
85
  self.i2v_model_name = cfg.MODEL.I2V.MODEL_NAME
@@ -170,6 +175,7 @@ class ChatBotUI(object):
170
  ]
171
 
172
  def create_ui(self):
 
173
  css = '.chatbot.prose.md {opacity: 1.0 !important} #chatbot {opacity: 1.0 !important}'
174
  with gr.Blocks(css=css,
175
  title='Chatbot',
@@ -180,7 +186,8 @@ class ChatBotUI(object):
180
  self.history_result = gr.State(value={})
181
  self.retry_msg = gr.State(value='')
182
  with gr.Group():
183
- with gr.Row(equal_height=True):
 
184
  with gr.Column(visible=True) as self.chat_page:
185
  self.chatbot = gr.Chatbot(
186
  height=600,
@@ -195,7 +202,7 @@ class ChatBotUI(object):
195
  size='sm')
196
 
197
  with gr.Column(visible=False) as self.editor_page:
198
- with gr.Tabs():
199
  with gr.Tab(id='ImageUploader',
200
  label='Image Uploader',
201
  visible=True) as self.upload_tab:
@@ -204,7 +211,7 @@ class ChatBotUI(object):
204
  interactive=True,
205
  type='pil',
206
  image_mode='RGB',
207
- sources='upload',
208
  elem_id='image_uploader',
209
  format='png')
210
  with gr.Row():
@@ -212,10 +219,9 @@ class ChatBotUI(object):
212
  value='Submit',
213
  elem_id='upload_submit')
214
  self.ext_btn_1 = gr.Button(value='Exit')
215
-
216
  with gr.Tab(id='ImageEditor',
217
- label='Image Editor',
218
- visible=False) as self.edit_tab:
219
  self.mask_type = gr.Dropdown(
220
  label='Mask Type',
221
  choices=[
@@ -278,13 +284,23 @@ class ChatBotUI(object):
278
  self.ext_btn_2 = gr.Button(value='Exit')
279
 
280
  with gr.Tab(id='ImageViewer',
281
- label='Image Viewer',
282
- visible=False) as self.image_view_tab:
283
- self.image_viewer = ImageSlider(
284
- label='Image',
285
- type='pil',
286
- show_download_button=True,
287
- elem_id='image_viewer')
 
 
 
 
 
 
 
 
 
 
288
 
289
  self.ext_btn_3 = gr.Button(value='Exit')
290
 
@@ -303,11 +319,30 @@ class ChatBotUI(object):
303
 
304
  self.ext_btn_4 = gr.Button(value='Exit')
305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
  with gr.Accordion(label='Setting', open=False):
307
  with gr.Row():
308
  self.model_name_dd = gr.Dropdown(
309
  choices=self.model_choices,
310
- value=self.model_name,
311
  label='Model Version')
312
 
313
  with gr.Row():
@@ -318,39 +353,63 @@ class ChatBotUI(object):
318
  label='Negative Prompt',
319
  container=False)
320
 
 
 
 
 
 
 
 
 
 
 
 
321
  with gr.Row():
322
  with gr.Column(scale=8, min_width=500):
323
  with gr.Row():
324
  self.step = gr.Slider(minimum=1,
325
  maximum=1000,
326
- value=20,
 
327
  label='Sample Step')
328
  self.cfg_scale = gr.Slider(
329
  minimum=1.0,
330
  maximum=20.0,
331
- value=4.5,
 
332
  label='Guidance Scale')
333
  self.rescale = gr.Slider(minimum=0.0,
334
  maximum=1.0,
335
- value=0.5,
 
336
  label='Rescale')
 
 
 
 
 
337
  self.seed = gr.Slider(minimum=-1,
338
  maximum=10000000,
339
  value=-1,
340
  label='Seed')
341
  self.output_height = gr.Slider(
342
  minimum=256,
343
- maximum=1024,
344
- value=512,
 
345
  label='Output Height')
346
  self.output_width = gr.Slider(
347
  minimum=256,
348
- maximum=1024,
349
- value=512,
 
350
  label='Output Width')
351
  with gr.Column(scale=1, min_width=50):
352
  self.use_history = gr.Checkbox(value=False,
353
  label='Use History')
 
 
 
354
  self.video_auto = gr.Checkbox(
355
  value=False,
356
  label='Auto Gen Video',
@@ -387,9 +446,9 @@ class ChatBotUI(object):
387
  visible=True)
388
 
389
  with gr.Row():
390
- inst = """
391
  **Instruction**:
392
-
393
  1. Click 'Upload' button to upload one or more images as input images.
394
  2. Enter '@' in the text box will exhibit all images in the gallery.
395
  3. Select the image you wish to edit from the gallery, and its Image ID will be displayed in the text box.
@@ -399,14 +458,27 @@ class ChatBotUI(object):
399
  6. **Important** To render text on an image, please ensure to include a space between each letter. For instance, "add text 'g i r l' on the mask area of @xxxxx".
400
  7. To implement local editing based on a specified mask, simply click on the image within the chat window to access the image editor. Here, you can draw a mask and then click the 'Submit' button to upload the edited image along with the mask. For inpainting tasks, select the 'Composite' mask type, while for outpainting tasks, choose the 'Outpainting' mask type. For all other local editing tasks, please select the 'Background' mask type.
401
  8. If you find our work valuable, we invite you to refer to the [ACE Page](https://ali-vilab.github.io/ace-page/) for comprehensive information.
402
-
403
  """
404
- gr.Markdown(value=inst)
405
-
 
 
 
 
 
 
 
 
 
 
 
 
 
406
  with gr.Row(variant='panel',
407
  equal_height=True,
408
  show_progress=False):
409
- with gr.Column(scale=1, min_width=100):
410
  self.upload_btn = gr.Button(value=upload_sty +
411
  ' Upload',
412
  variant='secondary')
@@ -416,12 +488,16 @@ class ChatBotUI(object):
416
  label='Instruction',
417
  container=False)
418
  with gr.Column(scale=1, min_width=100):
419
- self.chat_btn = gr.Button(value=chat_sty + ' Chat',
420
  variant='primary')
421
  with gr.Column(scale=1, min_width=100):
422
  self.retry_btn = gr.Button(value=refresh_sty +
423
  ' Retry',
424
  variant='secondary')
 
 
 
 
425
  with gr.Column(scale=(1 if self.enable_i2v else 0),
426
  min_width=0):
427
  self.video_gen_btn = gr.Button(value=video_sty +
@@ -457,19 +533,77 @@ class ChatBotUI(object):
457
  lock.acquire()
458
  del self.pipe
459
  torch.cuda.empty_cache()
460
- model_cfg = Config(load=True,
461
- cfg_file=self.model_choices[model_name])
462
  self.pipe = ACEInference()
463
- self.pipe.init_from_cfg(model_cfg)
464
  self.model_name = model_name
465
  lock.release()
466
 
467
- return model_name, gr.update(), gr.update()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
468
 
469
  self.model_name_dd.change(
470
  change_model,
471
  inputs=[self.model_name_dd],
472
- outputs=[self.model_name_dd, self.chatbot, self.text])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
473
 
474
  ########################################
475
  def generate_gallery(text, images):
@@ -516,7 +650,6 @@ class ChatBotUI(object):
516
  outputs=[self.text, self.gallery])
517
 
518
  ########################################
519
- @spaces.GPU(duration=120)
520
  def generate_video(message,
521
  extend_prompt,
522
  history,
@@ -527,6 +660,9 @@ class ChatBotUI(object):
527
  fps,
528
  seed,
529
  progress=gr.Progress(track_tqdm=True)):
 
 
 
530
  generator = torch.Generator(device='cuda').manual_seed(seed)
531
  img_ids = re.findall('@(.*?)[ ,;.?$]', message)
532
  if len(img_ids) == 0:
@@ -598,7 +734,11 @@ class ChatBotUI(object):
598
 
599
  ########################################
600
  @spaces.GPU(duration=60)
601
- def run_chat(message,
 
 
 
 
602
  extend_prompt,
603
  history,
604
  images,
@@ -607,6 +747,8 @@ class ChatBotUI(object):
607
  negative_prompt,
608
  cfg_scale,
609
  rescale,
 
 
610
  step,
611
  seed,
612
  output_h,
@@ -618,12 +760,25 @@ class ChatBotUI(object):
618
  video_fps,
619
  video_seed,
620
  progress=gr.Progress(track_tqdm=True)):
 
 
 
 
 
 
621
  retry_msg = message
622
  gen_id = get_md5(message)[:12]
623
  save_path = os.path.join(self.cache_dir, f'{gen_id}.png')
624
 
625
  img_ids = re.findall('@(.*?)[ ,;.?$]', message)
626
  history_io = None
 
 
 
 
 
 
 
627
  new_message = message
628
 
629
  if len(img_ids) > 0:
@@ -655,9 +810,9 @@ class ChatBotUI(object):
655
  history_io = history_result[img_id]
656
 
657
  buffered = io.BytesIO()
658
- edit_image[0].save(buffered, format='JPEG')
659
  img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
660
- img_str = f'<img src="data:image/jpg;base64,{img_b64}" style="pointer-events: none;">'
661
  pre_info = f'Received one or more images, so image editing is conducted.\n The first input image @{img_ids[0]} is:\n {img_str}'
662
  else:
663
  pre_info = 'No image ids were found in the provided text prompt, so text-guided image generation is conducted. \n'
@@ -682,6 +837,9 @@ class ChatBotUI(object):
682
  guide_scale=cfg_scale,
683
  guide_rescale=rescale,
684
  seed=seed,
 
 
 
685
  )
686
 
687
  img = imgs[0]
@@ -728,9 +886,9 @@ class ChatBotUI(object):
728
  }
729
 
730
  buffered = io.BytesIO()
731
- img.convert('RGB').save(buffered, format='JPEG')
732
  img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
733
- img_str = f'<img src="data:image/jpg;base64,{img_b64}" style="pointer-events: none;">'
734
 
735
  history.append(
736
  (message,
@@ -790,21 +948,25 @@ class ChatBotUI(object):
790
  while len(history) >= self.max_msgs:
791
  history.pop(0)
792
 
793
- return history, images, history_result, self.get_history(
794
- history), gr.update(value=''), gr.update(
795
- visible=False), retry_msg
 
796
 
797
  chat_inputs = [
 
798
  self.extend_prompt, self.history, self.images, self.use_history,
799
  self.history_result, self.negative_prompt, self.cfg_scale,
800
- self.rescale, self.step, self.seed, self.output_height,
 
801
  self.output_width, self.video_auto, self.video_step,
802
  self.video_frames, self.video_cfg_scale, self.video_fps,
803
  self.video_seed
804
  ]
805
 
806
  chat_outputs = [
807
- self.history, self.images, self.history_result, self.chatbot,
 
808
  self.text, self.gallery, self.retry_msg
809
  ]
810
 
@@ -848,9 +1010,9 @@ class ChatBotUI(object):
848
  edit_task.append('')
849
 
850
  buffered = io.BytesIO()
851
- img.save(buffered, format='JPEG')
852
  img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
853
- img_str = f'<img src="data:image/jpg;base64,{img_b64}" style="pointer-events: none;">'
854
  pre_info = f'Received one or more images, so image editing is conducted.\n The first input image is:\n {img_str}'
855
  else:
856
  pre_info = 'No image ids were found in the provided text prompt, so text-guided image generation is conducted. \n'
@@ -866,13 +1028,15 @@ class ChatBotUI(object):
866
  prompt=[prompt] * img_num,
867
  negative_prompt=[''] * img_num,
868
  seed=seed,
 
 
869
  )
870
 
871
  img = imgs[0]
872
  buffered = io.BytesIO()
873
- img.convert('RGB').save(buffered, format='JPEG')
874
  img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
875
- img_str = f'<img src="data:image/jpg;base64,{img_b64}" style="pointer-events: none;">'
876
  history = [(prompt,
877
  f'{pre_info} The generated image is:\n {img_str}')]
878
  return self.get_history(history), gr.update(value=''), gr.update(
@@ -911,21 +1075,23 @@ class ChatBotUI(object):
911
  return (gr.update(visible=True,
912
  scale=1), gr.update(visible=True, scale=1),
913
  gr.update(visible=True), gr.update(visible=False),
914
- gr.update(visible=False), gr.update(visible=False))
 
915
 
916
  self.upload_btn.click(upload_image,
917
  inputs=[],
918
  outputs=[
919
  self.chat_page, self.editor_page,
920
  self.upload_tab, self.edit_tab,
921
- self.image_view_tab, self.video_view_tab
 
922
  ])
923
 
924
  ########################################
925
  def edit_image(evt: gr.SelectData):
926
  if isinstance(evt.value, str):
927
  img_b64s = re.findall(
928
- '<img src="data:image/jpg;base64,(.*?)" style="pointer-events: none;">',
929
  evt.value)
930
  imgs = [
931
  Image.open(io.BytesIO(base64.b64decode(copy.deepcopy(i))))
@@ -933,13 +1099,19 @@ class ChatBotUI(object):
933
  ]
934
  if len(imgs) > 0:
935
  if len(imgs) == 2:
936
- view_img = copy.deepcopy(imgs)
 
 
 
937
  edit_img = copy.deepcopy(imgs[-1])
938
  else:
939
- view_img = [
940
- copy.deepcopy(imgs[-1]),
941
- copy.deepcopy(imgs[-1])
942
- ]
 
 
 
943
  edit_img = copy.deepcopy(imgs[-1])
944
 
945
  return (gr.update(visible=True,
@@ -948,11 +1120,12 @@ class ChatBotUI(object):
948
  gr.update(visible=False), gr.update(visible=True),
949
  gr.update(visible=True), gr.update(visible=False),
950
  gr.update(value=edit_img),
951
- gr.update(value=view_img), gr.update(value=None))
 
952
  else:
953
  return (gr.update(), gr.update(), gr.update(), gr.update(),
954
  gr.update(), gr.update(), gr.update(), gr.update(),
955
- gr.update())
956
  elif isinstance(evt.value, dict) and evt.value.get(
957
  'component', '') == 'video':
958
  value = evt.value['value']['video']['path']
@@ -960,11 +1133,12 @@ class ChatBotUI(object):
960
  scale=1), gr.update(visible=True, scale=1),
961
  gr.update(visible=False), gr.update(visible=False),
962
  gr.update(visible=False), gr.update(visible=True),
963
- gr.update(), gr.update(), gr.update(value=value))
 
964
  else:
965
  return (gr.update(), gr.update(), gr.update(), gr.update(),
966
  gr.update(), gr.update(), gr.update(), gr.update(),
967
- gr.update())
968
 
969
  self.chatbot.select(edit_image,
970
  outputs=[
@@ -972,16 +1146,17 @@ class ChatBotUI(object):
972
  self.upload_tab, self.edit_tab,
973
  self.image_view_tab, self.video_view_tab,
974
  self.image_editor, self.image_viewer,
975
- self.video_viewer
976
  ])
977
 
978
- self.image_viewer.change(lambda x: x,
979
- inputs=self.image_viewer,
980
- outputs=self.image_viewer)
 
981
 
982
  ########################################
983
  def submit_upload_image(image, history, images):
984
- history, images = self.add_uploaded_image_to_history(
985
  image, history, images)
986
  return gr.update(visible=False), gr.update(
987
  visible=True), gr.update(
@@ -1151,14 +1326,14 @@ class ChatBotUI(object):
1151
  thumbnail.save(thumbnail_path, format='JPEG')
1152
 
1153
  buffered = io.BytesIO()
1154
- img.convert('RGB').save(buffered, format='JPEG')
1155
  img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
1156
- img_str = f'<img src="data:image/jpg;base64,{img_b64}" style="pointer-events: none;">'
1157
 
1158
  buffered = io.BytesIO()
1159
- mask.convert('RGB').save(buffered, format='JPEG')
1160
  mask_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
1161
- mask_str = f'<img src="data:image/jpg;base64,{mask_b64}" style="pointer-events: none;">'
1162
 
1163
  images[img_id] = {
1164
  'image': save_path,
@@ -1207,19 +1382,18 @@ class ChatBotUI(object):
1207
  }
1208
 
1209
  buffered = io.BytesIO()
1210
- img.convert('RGB').save(buffered, format='JPEG')
1211
  img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
1212
- img_str = f'<img src="data:image/jpg;base64,{img_b64}" style="pointer-events: none;">'
1213
 
1214
  history.append(
1215
  (None,
1216
  f'This is uploaded image:\n {img_str} image ID is: {img_id}'))
1217
- return history, images
1218
 
1219
 
1220
 
1221
  if __name__ == '__main__':
1222
-
1223
  cfg = Config(cfg_file="config/chatbot_ui.yaml")
1224
  with gr.Blocks() as demo:
1225
  chatbot = ChatBotUI(cfg)
 
1
  # -*- coding: utf-8 -*-
2
  # Copyright (c) Alibaba, Inc. and its affiliates.
 
 
 
 
 
 
 
 
 
 
3
  import base64
4
  import copy
5
  import glob
6
  import io
7
+ import os, csv, sys
8
  import random
9
  import re
10
+ import shlex
11
  import string
12
+ import subprocess
13
  import threading
14
  import spaces
15
+
16
+ subprocess.run(shlex.split('pip install flash-attn --no-build-isolation'),
17
+ env=os.environ | {'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"})
18
+
19
  import cv2
20
  import gradio as gr
21
  import numpy as np
22
  import torch
23
  import transformers
 
 
 
24
  from PIL import Image
25
  from transformers import AutoModel, AutoTokenizer
26
 
27
+ from scepter.modules.inference.ace_inference import ACEInference
28
  from scepter.modules.utils.config import Config
29
  from scepter.modules.utils.directory import get_md5
30
  from scepter.modules.utils.file_system import FS
31
  from scepter.studio.utils.env import init_env
32
+ from importlib.metadata import version
33
 
34
+ from .example import get_examples
35
+ from .utils import load_image
 
36
 
37
+ csv.field_size_limit(sys.maxsize)
38
 
39
  refresh_sty = '\U0001f504' # 🔄
40
  clear_sty = '\U0001f5d1' # 🗑️
 
48
 
49
  class ChatBotUI(object):
50
  def __init__(self,
51
+ cfg_general_file,
52
+ is_debug=False,
53
+ language='en',
54
  root_work_dir='./'):
55
+ try:
56
+ from diffusers import CogVideoXImageToVideoPipeline
57
+ from diffusers.utils import export_to_video
58
+ except Exception as e:
59
+ print(f"Import diffusers failed, please install or upgrade diffusers. Error information: {e}")
60
 
61
+ cfg = Config(cfg_file=cfg_general_file)
62
  cfg.WORK_DIR = os.path.join(root_work_dir, cfg.WORK_DIR)
63
  if not FS.exists(cfg.WORK_DIR):
64
  FS.make_dir(cfg.WORK_DIR)
65
  cfg = init_env(cfg)
66
  self.cache_dir = cfg.WORK_DIR
67
+ self.chatbot_examples = get_examples(self.cache_dir) if not cfg.get('SKIP_EXAMPLES', False) else []
68
  self.model_cfg_dir = cfg.MODEL.EDIT_MODEL.MODEL_CFG_DIR
69
  self.model_yamls = glob.glob(os.path.join(self.model_cfg_dir,
70
  '*.yaml'))
71
  self.model_choices = dict()
72
+ self.default_model_name = ''
73
  for i in self.model_yamls:
74
+ model_cfg = Config(load=True, cfg_file=i)
75
+ model_name = model_cfg.NAME
76
+ if model_cfg.IS_DEFAULT: self.default_model_name = model_name
77
+ self.model_choices[model_name] = model_cfg
78
+ print('Models: ', self.model_choices.keys())
79
+ assert len(self.model_choices) > 0
80
+ if self.default_model_name == "": self.default_model_name = self.model_choices.keys()[0]
81
+ self.model_name = self.default_model_name
82
  self.pipe = ACEInference()
83
+ self.pipe.init_from_cfg(self.model_choices[self.default_model_name])
84
  self.max_msgs = 20
 
85
  self.enable_i2v = cfg.get('ENABLE_I2V', False)
86
+ self.gradio_version = version('gradio')
87
+
88
  if self.enable_i2v:
89
  self.i2v_model_dir = cfg.MODEL.I2V.MODEL_DIR
90
  self.i2v_model_name = cfg.MODEL.I2V.MODEL_NAME
 
175
  ]
176
 
177
  def create_ui(self):
178
+
179
  css = '.chatbot.prose.md {opacity: 1.0 !important} #chatbot {opacity: 1.0 !important}'
180
  with gr.Blocks(css=css,
181
  title='Chatbot',
 
186
  self.history_result = gr.State(value={})
187
  self.retry_msg = gr.State(value='')
188
  with gr.Group():
189
+ self.ui_mode = gr.State(value='legacy')
190
+ with gr.Row(equal_height=True, visible=False) as self.chat_group:
191
  with gr.Column(visible=True) as self.chat_page:
192
  self.chatbot = gr.Chatbot(
193
  height=600,
 
202
  size='sm')
203
 
204
  with gr.Column(visible=False) as self.editor_page:
205
+ with gr.Tabs(visible=False) as self.upload_tabs:
206
  with gr.Tab(id='ImageUploader',
207
  label='Image Uploader',
208
  visible=True) as self.upload_tab:
 
211
  interactive=True,
212
  type='pil',
213
  image_mode='RGB',
214
+ sources=['upload'],
215
  elem_id='image_uploader',
216
  format='png')
217
  with gr.Row():
 
219
  value='Submit',
220
  elem_id='upload_submit')
221
  self.ext_btn_1 = gr.Button(value='Exit')
222
+ with gr.Tabs(visible=False) as self.edit_tabs:
223
  with gr.Tab(id='ImageEditor',
224
+ label='Image Editor') as self.edit_tab:
 
225
  self.mask_type = gr.Dropdown(
226
  label='Mask Type',
227
  choices=[
 
284
  self.ext_btn_2 = gr.Button(value='Exit')
285
 
286
  with gr.Tab(id='ImageViewer',
287
+ label='Image Viewer') as self.image_view_tab:
288
+ if self.gradio_version >= '5.0.0':
289
+ self.image_viewer = gr.Image(
290
+ label='Image',
291
+ type='pil',
292
+ show_download_button=True,
293
+ elem_id='image_viewer')
294
+ else:
295
+ try:
296
+ from gradio_imageslider import ImageSlider
297
+ except Exception as e:
298
+ print(f"Import gradio_imageslider failed, please install.")
299
+ self.image_viewer = ImageSlider(
300
+ label='Image',
301
+ type='pil',
302
+ show_download_button=True,
303
+ elem_id='image_viewer')
304
 
305
  self.ext_btn_3 = gr.Button(value='Exit')
306
 
 
319
 
320
  self.ext_btn_4 = gr.Button(value='Exit')
321
 
322
+ with gr.Row(equal_height=True, visible=True) as self.legacy_group:
323
+ with gr.Column():
324
+ self.legacy_image_uploader = gr.Image(
325
+ height=550,
326
+ interactive=True,
327
+ type='pil',
328
+ image_mode='RGB',
329
+ elem_id='legacy_image_uploader',
330
+ format='png')
331
+ with gr.Column():
332
+ self.legacy_image_viewer = gr.Image(
333
+ label='Image',
334
+ height=550,
335
+ type='pil',
336
+ interactive=False,
337
+ show_download_button=True,
338
+ elem_id='image_viewer')
339
+
340
+
341
  with gr.Accordion(label='Setting', open=False):
342
  with gr.Row():
343
  self.model_name_dd = gr.Dropdown(
344
  choices=self.model_choices,
345
+ value=self.default_model_name,
346
  label='Model Version')
347
 
348
  with gr.Row():
 
353
  label='Negative Prompt',
354
  container=False)
355
 
356
+ with gr.Row():
357
+ # REFINER_PROMPT
358
+ self.refiner_prompt = gr.Textbox(
359
+ value=self.pipe.input.get("refiner_prompt", ""),
360
+ visible=self.pipe.input.get("refiner_prompt", None) is not None,
361
+ placeholder=
362
+ 'Prompt used for refiner',
363
+ label='Refiner Prompt',
364
+ container=False)
365
+
366
+
367
  with gr.Row():
368
  with gr.Column(scale=8, min_width=500):
369
  with gr.Row():
370
  self.step = gr.Slider(minimum=1,
371
  maximum=1000,
372
+ value=self.pipe.input.get("sample_steps", 20),
373
+ visible=self.pipe.input.get("sample_steps", None) is not None,
374
  label='Sample Step')
375
  self.cfg_scale = gr.Slider(
376
  minimum=1.0,
377
  maximum=20.0,
378
+ value=self.pipe.input.get("guide_scale", 4.5),
379
+ visible=self.pipe.input.get("guide_scale", None) is not None,
380
  label='Guidance Scale')
381
  self.rescale = gr.Slider(minimum=0.0,
382
  maximum=1.0,
383
+ value=self.pipe.input.get("guide_rescale", 0.5),
384
+ visible=self.pipe.input.get("guide_rescale", None) is not None,
385
  label='Rescale')
386
+ self.refiner_scale = gr.Slider(minimum=-0.1,
387
+ maximum=1.0,
388
+ value=self.pipe.input.get("refiner_scale", 0.5),
389
+ visible=self.pipe.input.get("refiner_scale", None) is not None,
390
+ label='Refiner Scale')
391
  self.seed = gr.Slider(minimum=-1,
392
  maximum=10000000,
393
  value=-1,
394
  label='Seed')
395
  self.output_height = gr.Slider(
396
  minimum=256,
397
+ maximum=1440,
398
+ value=self.pipe.input.get("output_height", 1024),
399
+ visible=self.pipe.input.get("output_height", None) is not None,
400
  label='Output Height')
401
  self.output_width = gr.Slider(
402
  minimum=256,
403
+ maximum=1440,
404
+ value=self.pipe.input.get("output_width", 1024),
405
+ visible=self.pipe.input.get("output_width", None) is not None,
406
  label='Output Width')
407
  with gr.Column(scale=1, min_width=50):
408
  self.use_history = gr.Checkbox(value=False,
409
  label='Use History')
410
+ self.use_ace = gr.Checkbox(value=self.pipe.input.get("use_ace", True),
411
+ visible=self.pipe.input.get("use_ace", None) is not None,
412
+ label='Use ACE')
413
  self.video_auto = gr.Checkbox(
414
  value=False,
415
  label='Auto Gen Video',
 
446
  visible=True)
447
 
448
  with gr.Row():
449
+ self.chatbot_inst = """
450
  **Instruction**:
451
+
452
  1. Click 'Upload' button to upload one or more images as input images.
453
  2. Enter '@' in the text box will exhibit all images in the gallery.
454
  3. Select the image you wish to edit from the gallery, and its Image ID will be displayed in the text box.
 
458
  6. **Important** To render text on an image, please ensure to include a space between each letter. For instance, "add text 'g i r l' on the mask area of @xxxxx".
459
  7. To implement local editing based on a specified mask, simply click on the image within the chat window to access the image editor. Here, you can draw a mask and then click the 'Submit' button to upload the edited image along with the mask. For inpainting tasks, select the 'Composite' mask type, while for outpainting tasks, choose the 'Outpainting' mask type. For all other local editing tasks, please select the 'Background' mask type.
460
  8. If you find our work valuable, we invite you to refer to the [ACE Page](https://ali-vilab.github.io/ace-page/) for comprehensive information.
461
+
462
  """
463
+
464
+ self.legacy_inst = """
465
+ **Instruction**:
466
+
467
+ 1. You can edit the image by uploading it; if no image is uploaded, an image will be generated from text..
468
+ 2. Enter '@' in the text box will exhibit all images in the gallery.
469
+ 3. Select the image you wish to edit from the gallery, and its Image ID will be displayed in the text box.
470
+ 4. **Important** To render text on an image, please ensure to include a space between each letter. For instance, "add text 'g i r l' on the mask area of @xxxxx".
471
+ 5. To perform multi-step editing, partial editing, inpainting, outpainting, and other operations, please click the Chatbot Checkbox to enable the conversational editing mode and follow the relevant instructions..
472
+ 6. If you find our work valuable, we invite you to refer to the [ACE Page](https://ali-vilab.github.io/ace-page/) for comprehensive information.
473
+
474
+ """
475
+
476
+ self.instruction = gr.Markdown(value=self.legacy_inst)
477
+
478
  with gr.Row(variant='panel',
479
  equal_height=True,
480
  show_progress=False):
481
+ with gr.Column(scale=1, min_width=100, visible=False) as self.upload_panel:
482
  self.upload_btn = gr.Button(value=upload_sty +
483
  ' Upload',
484
  variant='secondary')
 
488
  label='Instruction',
489
  container=False)
490
  with gr.Column(scale=1, min_width=100):
491
+ self.chat_btn = gr.Button(value='Generate',
492
  variant='primary')
493
  with gr.Column(scale=1, min_width=100):
494
  self.retry_btn = gr.Button(value=refresh_sty +
495
  ' Retry',
496
  variant='secondary')
497
+ with gr.Column(scale=1, min_width=100):
498
+ self.mode_checkbox = gr.Checkbox(
499
+ value=False,
500
+ label='ChatBot')
501
  with gr.Column(scale=(1 if self.enable_i2v else 0),
502
  min_width=0):
503
  self.video_gen_btn = gr.Button(value=video_sty +
 
533
  lock.acquire()
534
  del self.pipe
535
  torch.cuda.empty_cache()
 
 
536
  self.pipe = ACEInference()
537
+ self.pipe.init_from_cfg(self.model_choices[model_name])
538
  self.model_name = model_name
539
  lock.release()
540
 
541
+ return (model_name, gr.update(), gr.update(),
542
+ gr.Slider(
543
+ value=self.pipe.input.get("sample_steps", 20),
544
+ visible=self.pipe.input.get("sample_steps", None) is not None),
545
+ gr.Slider(
546
+ value=self.pipe.input.get("guide_scale", 4.5),
547
+ visible=self.pipe.input.get("guide_scale", None) is not None),
548
+ gr.Slider(
549
+ value=self.pipe.input.get("guide_rescale", 0.5),
550
+ visible=self.pipe.input.get("guide_rescale", None) is not None),
551
+ gr.Slider(
552
+ value=self.pipe.input.get("output_height", 1024),
553
+ visible=self.pipe.input.get("output_height", None) is not None),
554
+ gr.Slider(
555
+ value=self.pipe.input.get("output_width", 1024),
556
+ visible=self.pipe.input.get("output_width", None) is not None),
557
+ gr.Textbox(
558
+ value=self.pipe.input.get("refiner_prompt", ""),
559
+ visible=self.pipe.input.get("refiner_prompt", None) is not None),
560
+ gr.Slider(
561
+ value=self.pipe.input.get("refiner_scale", 0.5),
562
+ visible=self.pipe.input.get("refiner_scale", None) is not None
563
+ ),
564
+ gr.Checkbox(
565
+ value=self.pipe.input.get("use_ace", True),
566
+ visible=self.pipe.input.get("use_ace", None) is not None
567
+ )
568
+ )
569
 
570
  self.model_name_dd.change(
571
  change_model,
572
  inputs=[self.model_name_dd],
573
+ outputs=[
574
+ self.model_name_dd, self.chatbot, self.text,
575
+ self.step,
576
+ self.cfg_scale, self.rescale, self.output_height,
577
+ self.output_width, self.refiner_prompt, self.refiner_scale,
578
+ self.use_ace])
579
+
580
+
581
+ def mode_change(mode_check):  # Handler for self.mode_checkbox.change: picks the legacy or chatbot layout from the checkbox value.
582
+ if mode_check:
583
+ # ChatBot mode (checkbox ticked). Tuple order follows the handler's outputs list:
584
+ return (
585
+ gr.Row(visible=False),  # self.legacy_group: hide the uploader/viewer row
586
+ gr.Row(visible=True),  # self.chat_group: show the chatbot/editor row
587
+ gr.Button(value='Generate'),  # self.chat_btn label. NOTE(review): chat_btn is created as 'Generate' while ui_mode starts as 'legacy', so this label and the legacy one below look swapped - confirm intent.
588
+ gr.State(value='chatbot'),  # self.ui_mode
589
+ gr.Column(visible=True),  # self.upload_panel: show the column holding the Upload button
590
+ gr.Markdown(value=self.chatbot_inst)  # self.instruction: chatbot instruction text
591
+ )
592
+ else:
593
+ # Legacy mode (checkbox cleared). Same output order as the branch above:
594
+ return (
595
+ gr.Row(visible=True),  # self.legacy_group
596
+ gr.Row(visible=False),  # self.chat_group
597
+ gr.Button(value=chat_sty + ' Chat'),  # self.chat_btn label (see NOTE(review) in the ChatBot branch)
598
+ gr.State(value='legacy'),  # self.ui_mode
599
+ gr.Column(visible=False),  # self.upload_panel
600
+ gr.Markdown(value=self.legacy_inst)  # self.instruction: legacy instruction text
601
+ )
602
+ self.mode_checkbox.change(mode_change, inputs=[self.mode_checkbox],
603
+ outputs=[self.legacy_group, self.chat_group,
604
+ self.chat_btn, self.ui_mode,
605
+ self.upload_panel, self.instruction])
606
+
607
 
608
  ########################################
609
  def generate_gallery(text, images):
 
650
  outputs=[self.text, self.gallery])
651
 
652
  ########################################
 
653
  def generate_video(message,
654
  extend_prompt,
655
  history,
 
660
  fps,
661
  seed,
662
  progress=gr.Progress(track_tqdm=True)):
663
+
664
+ from diffusers.utils import export_to_video
665
+
666
  generator = torch.Generator(device='cuda').manual_seed(seed)
667
  img_ids = re.findall('@(.*?)[ ,;.?$]', message)
668
  if len(img_ids) == 0:
 
734
 
735
  ########################################
736
  @spaces.GPU(duration=60)
737
+ def run_chat(
738
+ message,
739
+ legacy_image,
740
+ ui_mode,
741
+ use_ace,
742
  extend_prompt,
743
  history,
744
  images,
 
747
  negative_prompt,
748
  cfg_scale,
749
  rescale,
750
+ refiner_prompt,
751
+ refiner_scale,
752
  step,
753
  seed,
754
  output_h,
 
760
  video_fps,
761
  video_seed,
762
  progress=gr.Progress(track_tqdm=True)):
763
+ legacy_img_ids = []
764
+ if ui_mode == 'legacy':
765
+ if legacy_image is not None:
766
+ history, images, img_id = self.add_uploaded_image_to_history(
767
+ legacy_image, history, images)
768
+ legacy_img_ids.append(img_id)
769
  retry_msg = message
770
  gen_id = get_md5(message)[:12]
771
  save_path = os.path.join(self.cache_dir, f'{gen_id}.png')
772
 
773
  img_ids = re.findall('@(.*?)[ ,;.?$]', message)
774
  history_io = None
775
+
776
+ if len(img_ids) < 1:
777
+ img_ids = legacy_img_ids
778
+ for img_id in img_ids:
779
+ if f'@{img_id}' not in message:
780
+ message = f'@{img_id} ' + message
781
+
782
  new_message = message
783
 
784
  if len(img_ids) > 0:
 
810
  history_io = history_result[img_id]
811
 
812
  buffered = io.BytesIO()
813
+ edit_image[0].save(buffered, format='PNG')
814
  img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
815
+ img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
816
  pre_info = f'Received one or more images, so image editing is conducted.\n The first input image @{img_ids[0]} is:\n {img_str}'
817
  else:
818
  pre_info = 'No image ids were found in the provided text prompt, so text-guided image generation is conducted. \n'
 
837
  guide_scale=cfg_scale,
838
  guide_rescale=rescale,
839
  seed=seed,
840
+ refiner_prompt=refiner_prompt,
841
+ refiner_scale=refiner_scale,
842
+ use_ace=use_ace
843
  )
844
 
845
  img = imgs[0]
 
886
  }
887
 
888
  buffered = io.BytesIO()
889
+ img.convert('RGB').save(buffered, format='PNG')
890
  img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
891
+ img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
892
 
893
  history.append(
894
  (message,
 
948
  while len(history) >= self.max_msgs:
949
  history.pop(0)
950
 
951
+ return (history, images, gr.Image(value=save_path),
952
+ history_result, self.get_history(
953
+ history), gr.update(), gr.update(
954
+ visible=False), retry_msg)
955
 
956
  chat_inputs = [
957
+ self.legacy_image_uploader, self.ui_mode, self.use_ace,
958
  self.extend_prompt, self.history, self.images, self.use_history,
959
  self.history_result, self.negative_prompt, self.cfg_scale,
960
+ self.rescale, self.refiner_prompt, self.refiner_scale,
961
+ self.step, self.seed, self.output_height,
962
  self.output_width, self.video_auto, self.video_step,
963
  self.video_frames, self.video_cfg_scale, self.video_fps,
964
  self.video_seed
965
  ]
966
 
967
  chat_outputs = [
968
+ self.history, self.images, self.legacy_image_viewer,
969
+ self.history_result, self.chatbot,
970
  self.text, self.gallery, self.retry_msg
971
  ]
972
 
 
1010
  edit_task.append('')
1011
 
1012
  buffered = io.BytesIO()
1013
+ img.save(buffered, format='PNG')
1014
  img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
1015
+ img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
1016
  pre_info = f'Received one or more images, so image editing is conducted.\n The first input image is:\n {img_str}'
1017
  else:
1018
  pre_info = 'No image ids were found in the provided text prompt, so text-guided image generation is conducted. \n'
 
1028
  prompt=[prompt] * img_num,
1029
  negative_prompt=[''] * img_num,
1030
  seed=seed,
1031
+ refiner_prompt=self.pipe.input.get("refiner_prompt", ""),
1032
+ refiner_scale=self.pipe.input.get("refiner_scale", 0.0),
1033
  )
1034
 
1035
  img = imgs[0]
1036
  buffered = io.BytesIO()
1037
+ img.convert('RGB').save(buffered, format='PNG')
1038
  img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
1039
+ img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
1040
  history = [(prompt,
1041
  f'{pre_info} The generated image is:\n {img_str}')]
1042
  return self.get_history(history), gr.update(value=''), gr.update(
 
1075
  return (gr.update(visible=True,
1076
  scale=1), gr.update(visible=True, scale=1),
1077
  gr.update(visible=True), gr.update(visible=False),
1078
+ gr.update(visible=False), gr.update(visible=False),
1079
+ gr.update(visible=True))
1080
 
1081
  self.upload_btn.click(upload_image,
1082
  inputs=[],
1083
  outputs=[
1084
  self.chat_page, self.editor_page,
1085
  self.upload_tab, self.edit_tab,
1086
+ self.image_view_tab, self.video_view_tab,
1087
+ self.upload_tabs
1088
  ])
1089
 
1090
  ########################################
1091
  def edit_image(evt: gr.SelectData):
1092
  if isinstance(evt.value, str):
1093
  img_b64s = re.findall(
1094
+ '<img src="data:image/png;base64,(.*?)" style="pointer-events: none;">',
1095
  evt.value)
1096
  imgs = [
1097
  Image.open(io.BytesIO(base64.b64decode(copy.deepcopy(i))))
 
1099
  ]
1100
  if len(imgs) > 0:
1101
  if len(imgs) == 2:
1102
+ if self.gradio_version >= '5.0.0':
1103
+ view_img = copy.deepcopy(imgs[-1])
1104
+ else:
1105
+ view_img = copy.deepcopy(imgs)
1106
  edit_img = copy.deepcopy(imgs[-1])
1107
  else:
1108
+ if self.gradio_version >= '5.0.0':
1109
+ view_img = copy.deepcopy(imgs[-1])
1110
+ else:
1111
+ view_img = [
1112
+ copy.deepcopy(imgs[-1]),
1113
+ copy.deepcopy(imgs[-1])
1114
+ ]
1115
  edit_img = copy.deepcopy(imgs[-1])
1116
 
1117
  return (gr.update(visible=True,
 
1120
  gr.update(visible=False), gr.update(visible=True),
1121
  gr.update(visible=True), gr.update(visible=False),
1122
  gr.update(value=edit_img),
1123
+ gr.update(value=view_img), gr.update(value=None),
1124
+ gr.update(visible=True))
1125
  else:
1126
  return (gr.update(), gr.update(), gr.update(), gr.update(),
1127
  gr.update(), gr.update(), gr.update(), gr.update(),
1128
+ gr.update(), gr.update())
1129
  elif isinstance(evt.value, dict) and evt.value.get(
1130
  'component', '') == 'video':
1131
  value = evt.value['value']['video']['path']
 
1133
  scale=1), gr.update(visible=True, scale=1),
1134
  gr.update(visible=False), gr.update(visible=False),
1135
  gr.update(visible=False), gr.update(visible=True),
1136
+ gr.update(), gr.update(), gr.update(value=value),
1137
+ gr.update())
1138
  else:
1139
  return (gr.update(), gr.update(), gr.update(), gr.update(),
1140
  gr.update(), gr.update(), gr.update(), gr.update(),
1141
+ gr.update(), gr.update())
1142
 
1143
  self.chatbot.select(edit_image,
1144
  outputs=[
 
1146
  self.upload_tab, self.edit_tab,
1147
  self.image_view_tab, self.video_view_tab,
1148
  self.image_editor, self.image_viewer,
1149
+ self.video_viewer, self.edit_tabs
1150
  ])
1151
 
1152
+ if self.gradio_version < '5.0.0':
1153
+ self.image_viewer.change(lambda x: x,
1154
+ inputs=self.image_viewer,
1155
+ outputs=self.image_viewer)
1156
 
1157
  ########################################
1158
  def submit_upload_image(image, history, images):
1159
+ history, images, _ = self.add_uploaded_image_to_history(
1160
  image, history, images)
1161
  return gr.update(visible=False), gr.update(
1162
  visible=True), gr.update(
 
1326
  thumbnail.save(thumbnail_path, format='JPEG')
1327
 
1328
  buffered = io.BytesIO()
1329
+ img.convert('RGB').save(buffered, format='PNG')
1330
  img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
1331
+ img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
1332
 
1333
  buffered = io.BytesIO()
1334
+ mask.convert('RGB').save(buffered, format='PNG')
1335
  mask_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
1336
+ mask_str = f'<img src="data:image/png;base64,{mask_b64}" style="pointer-events: none;">'
1337
 
1338
  images[img_id] = {
1339
  'image': save_path,
 
1382
  }
1383
 
1384
  buffered = io.BytesIO()
1385
+ img.convert('RGB').save(buffered, format='PNG')
1386
  img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
1387
+ img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
1388
 
1389
  history.append(
1390
  (None,
1391
  f'This is uploaded image:\n {img_str} image ID is: {img_id}'))
1392
+ return history, images, img_id
1393
 
1394
 
1395
 
1396
  if __name__ == '__main__':
 
1397
  cfg = Config(cfg_file="config/chatbot_ui.yaml")
1398
  with gr.Blocks() as demo:
1399
  chatbot = ChatBotUI(cfg)
utils.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import torch
3
+ import torchvision.transforms as T
4
+ from PIL import Image
5
+ from torchvision.transforms.functional import InterpolationMode
6
+
7
# Per-channel mean / standard deviation passed to ``T.Normalize`` below.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def build_transform(input_size):
    """Build the preprocessing pipeline applied to each image tile.

    The pipeline converts the image to RGB when it is in another mode,
    resizes it to ``input_size`` x ``input_size`` with bicubic
    interpolation, converts it to a tensor and normalizes it with
    ``IMAGENET_MEAN`` / ``IMAGENET_STD``.

    Args:
        input_size: Edge length of the square output, in pixels.

    Returns:
        A ``T.Compose`` chaining the steps above.
    """

    def _ensure_rgb(img):
        # Leave RGB images untouched; convert every other mode.
        return img if img.mode == 'RGB' else img.convert('RGB')

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size),
                 interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)
21
+
22
+
23
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
                              image_size):
    """Pick the candidate grid whose aspect ratio is closest to the image's.

    Args:
        aspect_ratio: Width / height of the source image.
        target_ratios: Iterable of ``(columns, rows)`` candidates, scanned
            in order.
        width: Source image width in pixels.
        height: Source image height in pixels.
        image_size: Edge length of one square tile in pixels.

    Returns:
        The chosen ``(columns, rows)`` candidate, or ``(1, 1)`` when
        ``target_ratios`` is empty.
    """
    pixel_count = width * height
    chosen = (1, 1)
    smallest_gap = float('inf')

    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])
        if gap < smallest_gap:
            # Strictly better match: always take it.
            smallest_gap, chosen = gap, candidate
        elif gap == smallest_gap and (
                pixel_count >
                0.5 * image_size * image_size * candidate[0] * candidate[1]):
            # Tie on aspect ratio: switch to this candidate only when the
            # source has more than half the pixels its tile grid would hold.
            chosen = candidate

    return chosen
38
+
39
+
40
def dynamic_preprocess(image,
                       min_num=1,
                       max_num=12,
                       image_size=448,
                       use_thumbnail=False):
    """Resize an image onto a tile grid and cut it into square tiles.

    A ``(columns, rows)`` grid with between ``min_num`` and ``max_num``
    tiles is chosen via ``find_closest_aspect_ratio``; the image is resized
    to exactly fill that grid and cropped into ``image_size`` squares in
    row-major order.

    Args:
        image: Source image; must expose ``size``, ``resize`` and ``crop``
            (presumably a ``PIL.Image.Image`` -- verify against callers).
        min_num: Minimum number of tiles a candidate grid may contain.
        max_num: Maximum number of tiles a candidate grid may contain.
        image_size: Edge length of each square tile in pixels.
        use_thumbnail: When true and more than one tile was produced, a
            whole-image thumbnail of ``image_size`` x ``image_size`` is
            appended after the tiles.

    Returns:
        List of tile images, optionally followed by the thumbnail.
    """
    src_w, src_h = image.size
    src_ratio = src_w / src_h

    # Enumerate every (columns, rows) grid whose tile count lies within
    # [min_num, max_num], then order the candidates by tile count.
    candidates = set((cols, rows) for n in range(min_num, max_num + 1)
                     for cols in range(1, n + 1) for rows in range(1, n + 1)
                     if cols * rows <= max_num and cols * rows >= min_num)
    candidates = sorted(candidates, key=lambda grid: grid[0] * grid[1])

    grid = find_closest_aspect_ratio(src_ratio, candidates, src_w, src_h,
                                     image_size)

    # Pixel dimensions of the resized image and the number of tiles in it.
    target_width = image_size * grid[0]
    target_height = image_size * grid[1]
    blocks = grid[0] * grid[1]

    resized = image.resize((target_width, target_height))

    # Walk the grid row by row; each tile is an image_size square.
    tiles_per_row = target_width // image_size
    tiles = []
    for index in range(blocks):
        row, col = divmod(index, tiles_per_row)
        left = col * image_size
        top = row * image_size
        tiles.append(
            resized.crop((left, top, left + image_size, top + image_size)))
    assert len(tiles) == blocks

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles
80
+
81
+
82
def load_image(image_file, input_size=448, max_num=12):
    """Load an image and turn it into a stacked tensor of preprocessed tiles.

    Args:
        image_file: Either a path string, which is opened and converted to
            RGB, or an already-loaded image object that is used as given.
        input_size: Edge length of each square tile in pixels.
        max_num: Maximum number of tiles passed to ``dynamic_preprocess``.

    Returns:
        The transformed tiles (with the thumbnail that
        ``dynamic_preprocess`` appends when there is more than one tile),
        combined with ``torch.stack``.
    """
    if isinstance(image_file, str):
        source = Image.open(image_file).convert('RGB')
    else:
        source = image_file

    to_tensor = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(source,
                               image_size=input_size,
                               use_thumbnail=True,
                               max_num=max_num)
    return torch.stack([to_tensor(tile) for tile in tiles])
95
+