Spaces: Running on Zero
modify app.py
app.py CHANGED
@@ -1,45 +1,40 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) Alibaba, Inc. and its affiliates.
-import os
-import shlex
-import subprocess
-subprocess.run(shlex.split('pip install flash-attn --no-build-isolation'), env=os.environ | {'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"})
-
-import sys
-import csv
-csv.field_size_limit(sys.maxsize)
-
-import argparse
 import base64
 import copy
 import glob
 import io
-import os
+import os, csv, sys
 import random
 import re
+import shlex
 import string
+import subprocess
 import threading
 import spaces
+
+subprocess.run(shlex.split('pip install flash-attn --no-build-isolation'),
+               env=os.environ | {'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"})
+
 import cv2
 import gradio as gr
 import numpy as np
 import torch
 import transformers
-from diffusers import CogVideoXImageToVideoPipeline
-from diffusers.utils import export_to_video
-from gradio_imageslider import ImageSlider
 from PIL import Image
 from transformers import AutoModel, AutoTokenizer

+from scepter.modules.inference.ace_inference import ACEInference
 from scepter.modules.utils.config import Config
 from scepter.modules.utils.directory import get_md5
 from scepter.modules.utils.file_system import FS
 from scepter.studio.utils.env import init_env
+from importlib.metadata import version

-from
-from
-from utils import load_image
+from .example import get_examples
+from .utils import load_image

+csv.field_size_limit(sys.maxsize)

 refresh_sty = '\U0001f504'  # 🔄
 clear_sty = '\U0001f5d1'  # 🗑️
@@ -53,33 +48,43 @@ lock = threading.Lock()

 class ChatBotUI(object):
     def __init__(self,
-
+                 cfg_general_file,
+                 is_debug=False,
+                 language='en',
                  root_work_dir='./'):
+        try:
+            from diffusers import CogVideoXImageToVideoPipeline
+            from diffusers.utils import export_to_video
+        except Exception as e:
+            print(f"Import diffusers failed, please install or upgrade diffusers. Error information: {e}")

+        cfg = Config(cfg_file=cfg_general_file)
         cfg.WORK_DIR = os.path.join(root_work_dir, cfg.WORK_DIR)
         if not FS.exists(cfg.WORK_DIR):
             FS.make_dir(cfg.WORK_DIR)
         cfg = init_env(cfg)
         self.cache_dir = cfg.WORK_DIR
-        self.chatbot_examples = get_examples(self.cache_dir)
+        self.chatbot_examples = get_examples(self.cache_dir) if not cfg.get('SKIP_EXAMPLES', False) else []
         self.model_cfg_dir = cfg.MODEL.EDIT_MODEL.MODEL_CFG_DIR
         self.model_yamls = glob.glob(os.path.join(self.model_cfg_dir,
                                                   '*.yaml'))
         self.model_choices = dict()
+        self.default_model_name = ''
         for i in self.model_yamls:
-
-
-
-
-
-            assert self.
-
-
+            model_cfg = Config(load=True, cfg_file=i)
+            model_name = model_cfg.NAME
+            if model_cfg.IS_DEFAULT: self.default_model_name = model_name
+            self.model_choices[model_name] = model_cfg
+        print('Models: ', self.model_choices.keys())
+        assert len(self.model_choices) > 0
+        if self.default_model_name == "": self.default_model_name = self.model_choices.keys()[0]
+        self.model_name = self.default_model_name
         self.pipe = ACEInference()
-        self.pipe.init_from_cfg(
+        self.pipe.init_from_cfg(self.model_choices[self.default_model_name])
         self.max_msgs = 20
-
         self.enable_i2v = cfg.get('ENABLE_I2V', False)
+        self.gradio_version = version('gradio')
+
         if self.enable_i2v:
             self.i2v_model_dir = cfg.MODEL.I2V.MODEL_DIR
             self.i2v_model_name = cfg.MODEL.I2V.MODEL_NAME
@@ -170,6 +175,7 @@ class ChatBotUI(object):
         ]

     def create_ui(self):
+
         css = '.chatbot.prose.md {opacity: 1.0 !important} #chatbot {opacity: 1.0 !important}'
         with gr.Blocks(css=css,
                        title='Chatbot',
@@ -180,7 +186,8 @@ class ChatBotUI(object):
             self.history_result = gr.State(value={})
             self.retry_msg = gr.State(value='')
             with gr.Group():
-
+                self.ui_mode = gr.State(value='legacy')
+                with gr.Row(equal_height=True, visible=False) as self.chat_group:
                     with gr.Column(visible=True) as self.chat_page:
                         self.chatbot = gr.Chatbot(
                             height=600,
@@ -195,7 +202,7 @@ class ChatBotUI(object):
                             size='sm')

                 with gr.Column(visible=False) as self.editor_page:
-                    with gr.Tabs():
+                    with gr.Tabs(visible=False) as self.upload_tabs:
                         with gr.Tab(id='ImageUploader',
                                     label='Image Uploader',
                                     visible=True) as self.upload_tab:
@@ -204,7 +211,7 @@ class ChatBotUI(object):
                                 interactive=True,
                                 type='pil',
                                 image_mode='RGB',
-                                sources='upload',
+                                sources=['upload'],
                                 elem_id='image_uploader',
                                 format='png')
                             with gr.Row():
@@ -212,10 +219,9 @@ class ChatBotUI(object):
                                     value='Submit',
                                     elem_id='upload_submit')
                                 self.ext_btn_1 = gr.Button(value='Exit')
-
+                    with gr.Tabs(visible=False) as self.edit_tabs:
                         with gr.Tab(id='ImageEditor',
-                                    label='Image Editor'
-                                    visible=False) as self.edit_tab:
+                                    label='Image Editor') as self.edit_tab:
                             self.mask_type = gr.Dropdown(
                                 label='Mask Type',
                                 choices=[
@@ -278,13 +284,23 @@ class ChatBotUI(object):
                             self.ext_btn_2 = gr.Button(value='Exit')

                         with gr.Tab(id='ImageViewer',
-                                    label='Image Viewer'
-
-
-
-
-
-
+                                    label='Image Viewer') as self.image_view_tab:
+                            if self.gradio_version >= '5.0.0':
+                                self.image_viewer = gr.Image(
+                                    label='Image',
+                                    type='pil',
+                                    show_download_button=True,
+                                    elem_id='image_viewer')
+                            else:
+                                try:
+                                    from gradio_imageslider import ImageSlider
+                                except Exception as e:
+                                    print(f"Import gradio_imageslider failed, please install.")
+                                self.image_viewer = ImageSlider(
+                                    label='Image',
+                                    type='pil',
+                                    show_download_button=True,
+                                    elem_id='image_viewer')

                             self.ext_btn_3 = gr.Button(value='Exit')

@@ -303,11 +319,30 @@ class ChatBotUI(object):

                         self.ext_btn_4 = gr.Button(value='Exit')

+                with gr.Row(equal_height=True, visible=True) as self.legacy_group:
+                    with gr.Column():
+                        self.legacy_image_uploader = gr.Image(
+                            height=550,
+                            interactive=True,
+                            type='pil',
+                            image_mode='RGB',
+                            elem_id='legacy_image_uploader',
+                            format='png')
+                    with gr.Column():
+                        self.legacy_image_viewer = gr.Image(
+                            label='Image',
+                            height=550,
+                            type='pil',
+                            interactive=False,
+                            show_download_button=True,
+                            elem_id='image_viewer')
+
+
             with gr.Accordion(label='Setting', open=False):
                 with gr.Row():
                     self.model_name_dd = gr.Dropdown(
                         choices=self.model_choices,
-                        value=self.
+                        value=self.default_model_name,
                         label='Model Version')

                 with gr.Row():
@@ -318,39 +353,63 @@ class ChatBotUI(object):
                         label='Negative Prompt',
                         container=False)

+                with gr.Row():
+                    # REFINER_PROMPT
+                    self.refiner_prompt = gr.Textbox(
+                        value=self.pipe.input.get("refiner_prompt", ""),
+                        visible=self.pipe.input.get("refiner_prompt", None) is not None,
+                        placeholder=
+                        'Prompt used for refiner',
+                        label='Refiner Prompt',
+                        container=False)
+
+
                 with gr.Row():
                     with gr.Column(scale=8, min_width=500):
                         with gr.Row():
                             self.step = gr.Slider(minimum=1,
                                                   maximum=1000,
-                                                  value=20,
+                                                  value=self.pipe.input.get("sample_steps", 20),
+                                                  visible=self.pipe.input.get("sample_steps", None) is not None,
                                                   label='Sample Step')
                             self.cfg_scale = gr.Slider(
                                 minimum=1.0,
                                 maximum=20.0,
-                                value=4.5,
+                                value=self.pipe.input.get("guide_scale", 4.5),
+                                visible=self.pipe.input.get("guide_scale", None) is not None,
                                 label='Guidance Scale')
                             self.rescale = gr.Slider(minimum=0.0,
                                                      maximum=1.0,
-                                                     value=0.5,
+                                                     value=self.pipe.input.get("guide_rescale", 0.5),
+                                                     visible=self.pipe.input.get("guide_rescale", None) is not None,
                                                      label='Rescale')
+                            self.refiner_scale = gr.Slider(minimum=-0.1,
+                                                           maximum=1.0,
+                                                           value=self.pipe.input.get("refiner_scale", 0.5),
+                                                           visible=self.pipe.input.get("refiner_scale", None) is not None,
+                                                           label='Refiner Scale')
                             self.seed = gr.Slider(minimum=-1,
                                                   maximum=10000000,
                                                   value=-1,
                                                   label='Seed')
                             self.output_height = gr.Slider(
                                 minimum=256,
-                                maximum=
-                                value=
+                                maximum=1440,
+                                value=self.pipe.input.get("output_height", 1024),
+                                visible=self.pipe.input.get("output_height", None) is not None,
                                 label='Output Height')
                             self.output_width = gr.Slider(
                                 minimum=256,
-                                maximum=
-                                value=
+                                maximum=1440,
+                                value=self.pipe.input.get("output_width", 1024),
+                                visible=self.pipe.input.get("output_width", None) is not None,
                                 label='Output Width')
                     with gr.Column(scale=1, min_width=50):
                         self.use_history = gr.Checkbox(value=False,
                                                        label='Use History')
+                        self.use_ace = gr.Checkbox(value=self.pipe.input.get("use_ace", True),
+                                                   visible=self.pipe.input.get("use_ace", None) is not None,
+                                                   label='Use ACE')
                         self.video_auto = gr.Checkbox(
                             value=False,
                             label='Auto Gen Video',
@@ -387,9 +446,9 @@ class ChatBotUI(object):
                                                visible=True)

             with gr.Row():
-
+                self.chatbot_inst = """
                 **Instruction**:
-
+
                 1. Click 'Upload' button to upload one or more images as input images.
                 2. Enter '@' in the text box will exhibit all images in the gallery.
                 3. Select the image you wish to edit from the gallery, and its Image ID will be displayed in the text box.
@@ -399,14 +458,27 @@ class ChatBotUI(object):
                 6. **Important** To render text on an image, please ensure to include a space between each letter. For instance, "add text 'g i r l' on the mask area of @xxxxx".
                 7. To implement local editing based on a specified mask, simply click on the image within the chat window to access the image editor. Here, you can draw a mask and then click the 'Submit' button to upload the edited image along with the mask. For inpainting tasks, select the 'Composite' mask type, while for outpainting tasks, choose the 'Outpainting' mask type. For all other local editing tasks, please select the 'Background' mask type.
                 8. If you find our work valuable, we invite you to refer to the [ACE Page](https://ali-vilab.github.io/ace-page/) for comprehensive information.
-
+
                 """
-
-
+
+                self.legacy_inst = """
+                **Instruction**:
+
+                1. You can edit the image by uploading it; if no image is uploaded, an image will be generated from text..
+                2. Enter '@' in the text box will exhibit all images in the gallery.
+                3. Select the image you wish to edit from the gallery, and its Image ID will be displayed in the text box.
+                4. **Important** To render text on an image, please ensure to include a space between each letter. For instance, "add text 'g i r l' on the mask area of @xxxxx".
+                5. To perform multi-step editing, partial editing, inpainting, outpainting, and other operations, please click the Chatbot Checkbox to enable the conversational editing mode and follow the relevant instructions..
+                6. If you find our work valuable, we invite you to refer to the [ACE Page](https://ali-vilab.github.io/ace-page/) for comprehensive information.
+
+                """
+
+                self.instruction = gr.Markdown(value=self.legacy_inst)
+
             with gr.Row(variant='panel',
                         equal_height=True,
                         show_progress=False):
-                with gr.Column(scale=1, min_width=100):
+                with gr.Column(scale=1, min_width=100, visible=False) as self.upload_panel:
                     self.upload_btn = gr.Button(value=upload_sty +
                                                 ' Upload',
                                                 variant='secondary')
@@ -416,12 +488,16 @@ class ChatBotUI(object):
                                           label='Instruction',
                                           container=False)
                 with gr.Column(scale=1, min_width=100):
-                    self.chat_btn = gr.Button(value=
+                    self.chat_btn = gr.Button(value='Generate',
                                               variant='primary')
                 with gr.Column(scale=1, min_width=100):
                     self.retry_btn = gr.Button(value=refresh_sty +
                                                ' Retry',
                                                variant='secondary')
+                with gr.Column(scale=1, min_width=100):
+                    self.mode_checkbox = gr.Checkbox(
+                        value=False,
+                        label='ChatBot')
                 with gr.Column(scale=(1 if self.enable_i2v else 0),
                                min_width=0):
                     self.video_gen_btn = gr.Button(value=video_sty +
@@ -457,19 +533,77 @@ class ChatBotUI(object):
            lock.acquire()
            del self.pipe
            torch.cuda.empty_cache()
-           model_cfg = Config(load=True,
-                              cfg_file=self.model_choices[model_name])
            self.pipe = ACEInference()
-           self.pipe.init_from_cfg(
+           self.pipe.init_from_cfg(self.model_choices[model_name])
            self.model_name = model_name
            lock.release()

-           return model_name, gr.update(), gr.update()
+           return (model_name, gr.update(), gr.update(),
+                   gr.Slider(
+                       value=self.pipe.input.get("sample_steps", 20),
+                       visible=self.pipe.input.get("sample_steps", None) is not None),
+                   gr.Slider(
+                       value=self.pipe.input.get("guide_scale", 4.5),
+                       visible=self.pipe.input.get("guide_scale", None) is not None),
+                   gr.Slider(
+                       value=self.pipe.input.get("guide_rescale", 0.5),
+                       visible=self.pipe.input.get("guide_rescale", None) is not None),
+                   gr.Slider(
+                       value=self.pipe.input.get("output_height", 1024),
+                       visible=self.pipe.input.get("output_height", None) is not None),
+                   gr.Slider(
+                       value=self.pipe.input.get("output_width", 1024),
+                       visible=self.pipe.input.get("output_width", None) is not None),
+                   gr.Textbox(
+                       value=self.pipe.input.get("refiner_prompt", ""),
+                       visible=self.pipe.input.get("refiner_prompt", None) is not None),
+                   gr.Slider(
+                       value=self.pipe.input.get("refiner_scale", 0.5),
+                       visible=self.pipe.input.get("refiner_scale", None) is not None
+                   ),
+                   gr.Checkbox(
+                       value=self.pipe.input.get("use_ace", True),
+                       visible=self.pipe.input.get("use_ace", None) is not None
+                   )
+                   )

         self.model_name_dd.change(
             change_model,
             inputs=[self.model_name_dd],
-            outputs=[
+            outputs=[
+                self.model_name_dd, self.chatbot, self.text,
+                self.step,
+                self.cfg_scale, self.rescale, self.output_height,
+                self.output_width, self.refiner_prompt, self.refiner_scale,
+                self.use_ace])
+
+
+        def mode_change(mode_check):
+            if mode_check:
+                # ChatBot
+                return (
+                    gr.Row(visible=False),
+                    gr.Row(visible=True),
+                    gr.Button(value='Generate'),
+                    gr.State(value='chatbot'),
+                    gr.Column(visible=True),
+                    gr.Markdown(value=self.chatbot_inst)
+                )
+            else:
+                # Legacy
+                return (
+                    gr.Row(visible=True),
+                    gr.Row(visible=False),
+                    gr.Button(value=chat_sty + ' Chat'),
+                    gr.State(value='legacy'),
+                    gr.Column(visible=False),
+                    gr.Markdown(value=self.legacy_inst)
+                )
+        self.mode_checkbox.change(mode_change, inputs=[self.mode_checkbox],
+                                  outputs=[self.legacy_group, self.chat_group,
+                                           self.chat_btn, self.ui_mode,
+                                           self.upload_panel, self.instruction])
+

         ########################################
         def generate_gallery(text, images):
@@ -516,7 +650,6 @@ class ChatBotUI(object):
                               outputs=[self.text, self.gallery])

         ########################################
-        @spaces.GPU(duration=120)
         def generate_video(message,
                            extend_prompt,
                            history,
@@ -527,6 +660,9 @@ class ChatBotUI(object):
                            fps,
                            seed,
                            progress=gr.Progress(track_tqdm=True)):
+
+            from diffusers.utils import export_to_video
+
             generator = torch.Generator(device='cuda').manual_seed(seed)
             img_ids = re.findall('@(.*?)[ ,;.?$]', message)
             if len(img_ids) == 0:
@@ -598,7 +734,11 @@ class ChatBotUI(object):

         ########################################
         @spaces.GPU(duration=60)
-        def run_chat(
+        def run_chat(
+                     message,
+                     legacy_image,
+                     ui_mode,
+                     use_ace,
                      extend_prompt,
                      history,
                      images,
@@ -607,6 +747,8 @@ class ChatBotUI(object):
                      negative_prompt,
                      cfg_scale,
                      rescale,
+                     refiner_prompt,
+                     refiner_scale,
                      step,
                      seed,
                      output_h,
@@ -618,12 +760,25 @@ class ChatBotUI(object):
                      video_fps,
                      video_seed,
                      progress=gr.Progress(track_tqdm=True)):
+            legacy_img_ids = []
+            if ui_mode == 'legacy':
+                if legacy_image is not None:
+                    history, images, img_id = self.add_uploaded_image_to_history(
+                        legacy_image, history, images)
+                    legacy_img_ids.append(img_id)
             retry_msg = message
             gen_id = get_md5(message)[:12]
             save_path = os.path.join(self.cache_dir, f'{gen_id}.png')

             img_ids = re.findall('@(.*?)[ ,;.?$]', message)
             history_io = None
+
+            if len(img_ids) < 1:
+                img_ids = legacy_img_ids
+            for img_id in img_ids:
+                if f'@{img_id}' not in message:
+                    message = f'@{img_id} ' + message
+
             new_message = message

             if len(img_ids) > 0:
@@ -655,9 +810,9 @@ class ChatBotUI(object):
                        history_io = history_result[img_id]

                buffered = io.BytesIO()
-               edit_image[0].save(buffered, format='
+               edit_image[0].save(buffered, format='PNG')
                img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
-               img_str = f'<img src="data:image/
+               img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
                pre_info = f'Received one or more images, so image editing is conducted.\n The first input image @{img_ids[0]} is:\n {img_str}'
            else:
                pre_info = 'No image ids were found in the provided text prompt, so text-guided image generation is conducted. \n'
@@ -682,6 +837,9 @@ class ChatBotUI(object):
                guide_scale=cfg_scale,
                guide_rescale=rescale,
                seed=seed,
+               refiner_prompt=refiner_prompt,
+               refiner_scale=refiner_scale,
+               use_ace=use_ace
            )

            img = imgs[0]
@@ -728,9 +886,9 @@ class ChatBotUI(object):
            }

            buffered = io.BytesIO()
-           img.convert('RGB').save(buffered, format='
+           img.convert('RGB').save(buffered, format='PNG')
            img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
-           img_str = f'<img src="data:image/
+           img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'

            history.append(
                (message,
@@ -790,21 +948,25 @@ class ChatBotUI(object):
            while len(history) >= self.max_msgs:
                history.pop(0)

-           return history, images,
-
-
+           return (history, images, gr.Image(value=save_path),
+                   history_result, self.get_history(
+                       history), gr.update(), gr.update(
+                           visible=False), retry_msg)

        chat_inputs = [
+           self.legacy_image_uploader, self.ui_mode, self.use_ace,
            self.extend_prompt, self.history, self.images, self.use_history,
            self.history_result, self.negative_prompt, self.cfg_scale,
-           self.rescale, self.
+           self.rescale, self.refiner_prompt, self.refiner_scale,
+           self.step, self.seed, self.output_height,
            self.output_width, self.video_auto, self.video_step,
            self.video_frames, self.video_cfg_scale, self.video_fps,
            self.video_seed
        ]

        chat_outputs = [
-           self.history, self.images, self.
+           self.history, self.images, self.legacy_image_viewer,
+           self.history_result, self.chatbot,
            self.text, self.gallery, self.retry_msg
        ]

@@ -848,9 +1010,9 @@ class ChatBotUI(object):
                edit_task.append('')

                buffered = io.BytesIO()
-               img.save(buffered, format='
+               img.save(buffered, format='PNG')
                img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
-               img_str = f'<img src="data:image/
+               img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
                pre_info = f'Received one or more images, so image editing is conducted.\n The first input image is:\n {img_str}'
            else:
                pre_info = 'No image ids were found in the provided text prompt, so text-guided image generation is conducted. \n'
@@ -866,13 +1028,15 @@ class ChatBotUI(object):
                prompt=[prompt] * img_num,
                negative_prompt=[''] * img_num,
                seed=seed,
+               refiner_prompt=self.pipe.input.get("refiner_prompt", ""),
+               refiner_scale=self.pipe.input.get("refiner_scale", 0.0),
            )

            img = imgs[0]
            buffered = io.BytesIO()
-           img.convert('RGB').save(buffered, format='
+           img.convert('RGB').save(buffered, format='PNG')
            img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
-           img_str = f'<img src="data:image/
+           img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
            history = [(prompt,
                        f'{pre_info} The generated image is:\n {img_str}')]
            return self.get_history(history), gr.update(value=''), gr.update(
@@ -911,21 +1075,23 @@ class ChatBotUI(object):
            return (gr.update(visible=True,
                              scale=1), gr.update(visible=True, scale=1),
                    gr.update(visible=True), gr.update(visible=False),
-                   gr.update(visible=False), gr.update(visible=False)
+                   gr.update(visible=False), gr.update(visible=False),
+                   gr.update(visible=True))

        self.upload_btn.click(upload_image,
                              inputs=[],
                              outputs=[
                                  self.chat_page, self.editor_page,
                                  self.upload_tab, self.edit_tab,
-                                 self.image_view_tab, self.video_view_tab
+                                 self.image_view_tab, self.video_view_tab,
+                                 self.upload_tabs
                              ])

        ########################################
        def edit_image(evt: gr.SelectData):
            if isinstance(evt.value, str):
                img_b64s = re.findall(
-                   '<img src="data:image/
+                   '<img src="data:image/png;base64,(.*?)" style="pointer-events: none;">',
                    evt.value)
                imgs = [
                    Image.open(io.BytesIO(base64.b64decode(copy.deepcopy(i))))
@@ -933,13 +1099,19 @@ class ChatBotUI(object):
                ]
                if len(imgs) > 0:
                    if len(imgs) == 2:
-
+                        if self.gradio_version >= '5.0.0':
+                            view_img = copy.deepcopy(imgs[-1])
+                        else:
+                            view_img = copy.deepcopy(imgs)
                        edit_img = copy.deepcopy(imgs[-1])
                    else:
-
-                            copy.deepcopy(imgs[-1])
-
-
+                        if self.gradio_version >= '5.0.0':
+                            view_img = copy.deepcopy(imgs[-1])
+                        else:
+                            view_img = [
+                                copy.deepcopy(imgs[-1]),
+                                copy.deepcopy(imgs[-1])
+                            ]
                        edit_img = copy.deepcopy(imgs[-1])

                    return (gr.update(visible=True,
@@ -948,11 +1120,12 @@ class ChatBotUI(object):
                            gr.update(visible=False), gr.update(visible=True),
                            gr.update(visible=True), gr.update(visible=False),
                            gr.update(value=edit_img),
-                           gr.update(value=view_img), gr.update(value=None)
+                           gr.update(value=view_img), gr.update(value=None),
+                           gr.update(visible=True))
                else:
                    return (gr.update(), gr.update(), gr.update(), gr.update(),
                            gr.update(), gr.update(), gr.update(), gr.update(),
-                           gr.update())
+                           gr.update(), gr.update())
            elif isinstance(evt.value, dict) and evt.value.get(
                    'component', '') == 'video':
                value = evt.value['value']['video']['path']
@@ -960,11 +1133,12 @@ class ChatBotUI(object):
                                   scale=1), gr.update(visible=True, scale=1),
                        gr.update(visible=False), gr.update(visible=False),
                        gr.update(visible=False), gr.update(visible=True),
-                       gr.update(), gr.update(), gr.update(value=value)
+                       gr.update(), gr.update(), gr.update(value=value),
+                       gr.update())
            else:
                return (gr.update(), gr.update(), gr.update(), gr.update(),
                        gr.update(), gr.update(), gr.update(), gr.update(),
-                       gr.update())
+                       gr.update(), gr.update())

        self.chatbot.select(edit_image,
                            outputs=[
@@ -972,16 +1146,17 @@ class ChatBotUI(object):
                                self.upload_tab, self.edit_tab,
                                self.image_view_tab, self.video_view_tab,
                                self.image_editor, self.image_viewer,
-                               self.video_viewer
+                               self.video_viewer, self.edit_tabs
                            ])

-       self.
-
-
+       if self.gradio_version < '5.0.0':
+           self.image_viewer.change(lambda x: x,
+                                    inputs=self.image_viewer,
+                                    outputs=self.image_viewer)

        ########################################
        def submit_upload_image(image, history, images):
-           history, images = self.add_uploaded_image_to_history(
+           history, images, _ = self.add_uploaded_image_to_history(
                image, history, images)
            return gr.update(visible=False), gr.update(
                visible=True), gr.update(
@@ -1151,14 +1326,14 @@ class ChatBotUI(object):
            thumbnail.save(thumbnail_path, format='JPEG')

            buffered = io.BytesIO()
-           img.convert('RGB').save(buffered, format='
+           img.convert('RGB').save(buffered, format='PNG')
            img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
-           img_str = f'<img src="data:image/
+           img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'

            buffered = io.BytesIO()
-           mask.convert('RGB').save(buffered, format='
+           mask.convert('RGB').save(buffered, format='PNG')
            mask_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
-           mask_str = f'<img src="data:image/
+           mask_str = f'<img src="data:image/png;base64,{mask_b64}" style="pointer-events: none;">'

            images[img_id] = {
                'image': save_path,
@@ -1207,19 +1382,18 @@ class ChatBotUI(object):
            }

            buffered = io.BytesIO()
-           img.convert('RGB').save(buffered, format='
+           img.convert('RGB').save(buffered, format='PNG')
            img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
-           img_str = f'<img src="data:image/
+           img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'

            history.append(
                (None,
                 f'This is uploaded image:\n {img_str} image ID is: {img_id}'))
-           return history, images
+           return history, images, img_id



if __name__ == '__main__':
-
    cfg = Config(cfg_file="config/chatbot_ui.yaml")
    with gr.Blocks() as demo:
        chatbot = ChatBotUI(cfg)
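For reference, the entry-point block at the end of app.py appears only partially in this diff. A minimal sketch of how such a Space entry point typically continues is shown below; the create_ui() and launch() calls are assumptions, not part of this commit:

# Hypothetical continuation of the __main__ block; only the first three
# statements appear in the diff above.
cfg = Config(cfg_file="config/chatbot_ui.yaml")
with gr.Blocks() as demo:
    chatbot = ChatBotUI(cfg)
    chatbot.create_ui()   # assumption: builds the layout defined in create_ui()
demo.launch()             # assumption: standard Gradio launch call for a Space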
utils.py ADDED
@@ -0,0 +1,95 @@
+#copyright (c) Alibaba, Inc. and its affiliates.
+import torch
+import torchvision.transforms as T
+from PIL import Image
+from torchvision.transforms.functional import InterpolationMode
+
+IMAGENET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_STD = (0.229, 0.224, 0.225)
+
+
+def build_transform(input_size):
+    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
+    transform = T.Compose([
+        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+        T.Resize((input_size, input_size),
+                 interpolation=InterpolationMode.BICUBIC),
+        T.ToTensor(),
+        T.Normalize(mean=MEAN, std=STD)
+    ])
+    return transform
+
+
+def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
+                              image_size):
+    best_ratio_diff = float('inf')
+    best_ratio = (1, 1)
+    area = width * height
+    for ratio in target_ratios:
+        target_aspect_ratio = ratio[0] / ratio[1]
+        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+        if ratio_diff < best_ratio_diff:
+            best_ratio_diff = ratio_diff
+            best_ratio = ratio
+        elif ratio_diff == best_ratio_diff:
+            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                best_ratio = ratio
+    return best_ratio
+
+
+def dynamic_preprocess(image,
+                       min_num=1,
+                       max_num=12,
+                       image_size=448,
+                       use_thumbnail=False):
+    orig_width, orig_height = image.size
+    aspect_ratio = orig_width / orig_height
+
+    # calculate the existing image aspect ratio
+    target_ratios = set((i, j) for n in range(min_num, max_num + 1)
+                        for i in range(1, n + 1) for j in range(1, n + 1)
+                        if i * j <= max_num and i * j >= min_num)
+    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+    # find the closest aspect ratio to the target
+    target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
+                                                    target_ratios, orig_width,
+                                                    orig_height, image_size)
+
+    # calculate the target width and height
+    target_width = image_size * target_aspect_ratio[0]
+    target_height = image_size * target_aspect_ratio[1]
+    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+    # resize the image
+    resized_img = image.resize((target_width, target_height))
+    processed_images = []
+    for i in range(blocks):
+        box = ((i % (target_width // image_size)) * image_size,
+               (i // (target_width // image_size)) * image_size,
+               ((i % (target_width // image_size)) + 1) * image_size,
+               ((i // (target_width // image_size)) + 1) * image_size)
+        # split the image
+        split_img = resized_img.crop(box)
+        processed_images.append(split_img)
+    assert len(processed_images) == blocks
+    if use_thumbnail and len(processed_images) != 1:
+        thumbnail_img = image.resize((image_size, image_size))
+        processed_images.append(thumbnail_img)
+    return processed_images
+
+
+def load_image(image_file, input_size=448, max_num=12):
+    if isinstance(image_file, str):
+        image = Image.open(image_file).convert('RGB')
+    else:
+        image = image_file
+    transform = build_transform(input_size=input_size)
+    images = dynamic_preprocess(image,
+                                image_size=input_size,
+                                use_thumbnail=True,
+                                max_num=max_num)
+    pixel_values = [transform(image) for image in images]
+    pixel_values = torch.stack(pixel_values)
+    return pixel_values
+
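A minimal usage sketch for the new helper. The file name and the standalone import path are illustrative assumptions; inside the Space, app.py imports it as `from .utils import load_image`:

# Hypothetical standalone usage of the added utils.load_image helper.
from utils import load_image

# 'example.jpg' is a placeholder path; a PIL.Image object also works.
pixel_values = load_image('example.jpg', input_size=448, max_num=12)

# dynamic_preprocess tiles the resized image into 448x448 crops (plus a
# thumbnail when more than one tile is produced), and each tile is
# normalized to a tensor, so the stacked result has shape
# (num_tiles, 3, 448, 448).
print(pixel_values.shape)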