Spaces:
Running
on
Zero
Running
on
Zero
chaojiemao
committed on
Commit
•
ec9288d
1
Parent(s):
78e9f55
Update app.py
Browse files
app.py
CHANGED
@@ -1,45 +1,40 @@
|
|
1 |
# -*- coding: utf-8 -*-
|
2 |
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
-
import os
|
4 |
-
import shlex
|
5 |
-
import subprocess
|
6 |
-
subprocess.run(shlex.split('pip install flash-attn --no-build-isolation'), env=os.environ | {'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"})
|
7 |
-
|
8 |
-
import sys
|
9 |
-
import csv
|
10 |
-
csv.field_size_limit(sys.maxsize)
|
11 |
-
|
12 |
-
import argparse
|
13 |
import base64
|
14 |
import copy
|
15 |
import glob
|
16 |
import io
|
17 |
-
import os
|
18 |
import random
|
19 |
import re
|
|
|
20 |
import string
|
|
|
21 |
import threading
|
22 |
import spaces
|
|
|
|
|
|
|
|
|
23 |
import cv2
|
24 |
import gradio as gr
|
25 |
import numpy as np
|
26 |
import torch
|
27 |
import transformers
|
28 |
-
from diffusers import CogVideoXImageToVideoPipeline
|
29 |
-
from diffusers.utils import export_to_video
|
30 |
-
from gradio_imageslider import ImageSlider
|
31 |
from PIL import Image
|
32 |
from transformers import AutoModel, AutoTokenizer
|
33 |
|
|
|
34 |
from scepter.modules.utils.config import Config
|
35 |
from scepter.modules.utils.directory import get_md5
|
36 |
from scepter.modules.utils.file_system import FS
|
37 |
from scepter.studio.utils.env import init_env
|
|
|
38 |
|
39 |
-
from infer import ACEInference
|
40 |
from example import get_examples
|
41 |
from utils import load_image
|
42 |
|
|
|
43 |
|
44 |
refresh_sty = '\U0001f504' # 🔄
|
45 |
clear_sty = '\U0001f5d1' # 🗑️
|
@@ -53,33 +48,56 @@ lock = threading.Lock()
|
|
53 |
|
54 |
class ChatBotUI(object):
|
55 |
def __init__(self,
|
56 |
-
|
|
|
|
|
57 |
root_work_dir='./'):
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
cfg.WORK_DIR = os.path.join(root_work_dir, cfg.WORK_DIR)
|
60 |
if not FS.exists(cfg.WORK_DIR):
|
61 |
FS.make_dir(cfg.WORK_DIR)
|
62 |
cfg = init_env(cfg)
|
63 |
self.cache_dir = cfg.WORK_DIR
|
64 |
-
self.chatbot_examples = get_examples(self.cache_dir)
|
65 |
self.model_cfg_dir = cfg.MODEL.EDIT_MODEL.MODEL_CFG_DIR
|
66 |
self.model_yamls = glob.glob(os.path.join(self.model_cfg_dir,
|
67 |
'*.yaml'))
|
68 |
self.model_choices = dict()
|
|
|
69 |
for i in self.model_yamls:
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
self.pipe = ACEInference()
|
79 |
-
self.pipe.init_from_cfg(
|
80 |
self.max_msgs = 20
|
81 |
-
|
82 |
self.enable_i2v = cfg.get('ENABLE_I2V', False)
|
|
|
|
|
83 |
if self.enable_i2v:
|
84 |
self.i2v_model_dir = cfg.MODEL.I2V.MODEL_DIR
|
85 |
self.i2v_model_name = cfg.MODEL.I2V.MODEL_NAME
|
@@ -115,15 +133,11 @@ class ChatBotUI(object):
|
|
115 |
)
|
116 |
|
117 |
sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
|
118 |
-
|
119 |
For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
|
120 |
There are a few rules to follow:
|
121 |
-
|
122 |
You will only ever output a single video description per user request.
|
123 |
-
|
124 |
When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
|
125 |
Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
|
126 |
-
|
127 |
Video descriptions must have the same num of words as examples below. Extra words will be ignored.
|
128 |
"""
|
129 |
self.enhance_ctx = [
|
@@ -170,6 +184,7 @@ class ChatBotUI(object):
|
|
170 |
]
|
171 |
|
172 |
def create_ui(self):
|
|
|
173 |
css = '.chatbot.prose.md {opacity: 1.0 !important} #chatbot {opacity: 1.0 !important}'
|
174 |
with gr.Blocks(css=css,
|
175 |
title='Chatbot',
|
@@ -180,7 +195,8 @@ class ChatBotUI(object):
|
|
180 |
self.history_result = gr.State(value={})
|
181 |
self.retry_msg = gr.State(value='')
|
182 |
with gr.Group():
|
183 |
-
|
|
|
184 |
with gr.Column(visible=True) as self.chat_page:
|
185 |
self.chatbot = gr.Chatbot(
|
186 |
height=600,
|
@@ -195,7 +211,7 @@ class ChatBotUI(object):
|
|
195 |
size='sm')
|
196 |
|
197 |
with gr.Column(visible=False) as self.editor_page:
|
198 |
-
with gr.Tabs():
|
199 |
with gr.Tab(id='ImageUploader',
|
200 |
label='Image Uploader',
|
201 |
visible=True) as self.upload_tab:
|
@@ -204,7 +220,7 @@ class ChatBotUI(object):
|
|
204 |
interactive=True,
|
205 |
type='pil',
|
206 |
image_mode='RGB',
|
207 |
-
sources='upload',
|
208 |
elem_id='image_uploader',
|
209 |
format='png')
|
210 |
with gr.Row():
|
@@ -212,10 +228,9 @@ class ChatBotUI(object):
|
|
212 |
value='Submit',
|
213 |
elem_id='upload_submit')
|
214 |
self.ext_btn_1 = gr.Button(value='Exit')
|
215 |
-
|
216 |
with gr.Tab(id='ImageEditor',
|
217 |
-
label='Image Editor'
|
218 |
-
visible=False) as self.edit_tab:
|
219 |
self.mask_type = gr.Dropdown(
|
220 |
label='Mask Type',
|
221 |
choices=[
|
@@ -278,13 +293,23 @@ class ChatBotUI(object):
|
|
278 |
self.ext_btn_2 = gr.Button(value='Exit')
|
279 |
|
280 |
with gr.Tab(id='ImageViewer',
|
281 |
-
label='Image Viewer'
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
288 |
|
289 |
self.ext_btn_3 = gr.Button(value='Exit')
|
290 |
|
@@ -303,11 +328,30 @@ class ChatBotUI(object):
|
|
303 |
|
304 |
self.ext_btn_4 = gr.Button(value='Exit')
|
305 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
306 |
with gr.Accordion(label='Setting', open=False):
|
307 |
with gr.Row():
|
308 |
self.model_name_dd = gr.Dropdown(
|
309 |
choices=self.model_choices,
|
310 |
-
value=self.
|
311 |
label='Model Version')
|
312 |
|
313 |
with gr.Row():
|
@@ -318,39 +362,63 @@ class ChatBotUI(object):
|
|
318 |
label='Negative Prompt',
|
319 |
container=False)
|
320 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
321 |
with gr.Row():
|
322 |
with gr.Column(scale=8, min_width=500):
|
323 |
with gr.Row():
|
324 |
self.step = gr.Slider(minimum=1,
|
325 |
maximum=1000,
|
326 |
-
value=20,
|
|
|
327 |
label='Sample Step')
|
328 |
self.cfg_scale = gr.Slider(
|
329 |
minimum=1.0,
|
330 |
maximum=20.0,
|
331 |
-
value=4.5,
|
|
|
332 |
label='Guidance Scale')
|
333 |
self.rescale = gr.Slider(minimum=0.0,
|
334 |
maximum=1.0,
|
335 |
-
value=0.5,
|
|
|
336 |
label='Rescale')
|
|
|
|
|
|
|
|
|
|
|
337 |
self.seed = gr.Slider(minimum=-1,
|
338 |
maximum=10000000,
|
339 |
value=-1,
|
340 |
label='Seed')
|
341 |
self.output_height = gr.Slider(
|
342 |
minimum=256,
|
343 |
-
maximum=
|
344 |
-
value=
|
|
|
345 |
label='Output Height')
|
346 |
self.output_width = gr.Slider(
|
347 |
minimum=256,
|
348 |
-
maximum=
|
349 |
-
value=
|
|
|
350 |
label='Output Width')
|
351 |
with gr.Column(scale=1, min_width=50):
|
352 |
self.use_history = gr.Checkbox(value=False,
|
353 |
label='Use History')
|
|
|
|
|
|
|
354 |
self.video_auto = gr.Checkbox(
|
355 |
value=False,
|
356 |
label='Auto Gen Video',
|
@@ -387,9 +455,8 @@ class ChatBotUI(object):
|
|
387 |
visible=True)
|
388 |
|
389 |
with gr.Row():
|
390 |
-
|
391 |
**Instruction**:
|
392 |
-
|
393 |
1. Click 'Upload' button to upload one or more images as input images.
|
394 |
2. Enter '@' in the text box will exhibit all images in the gallery.
|
395 |
3. Select the image you wish to edit from the gallery, and its Image ID will be displayed in the text box.
|
@@ -399,14 +466,24 @@ class ChatBotUI(object):
|
|
399 |
6. **Important** To render text on an image, please ensure to include a space between each letter. For instance, "add text 'g i r l' on the mask area of @xxxxx".
|
400 |
7. To implement local editing based on a specified mask, simply click on the image within the chat window to access the image editor. Here, you can draw a mask and then click the 'Submit' button to upload the edited image along with the mask. For inpainting tasks, select the 'Composite' mask type, while for outpainting tasks, choose the 'Outpainting' mask type. For all other local editing tasks, please select the 'Background' mask type.
|
401 |
8. If you find our work valuable, we invite you to refer to the [ACE Page](https://ali-vilab.github.io/ace-page/) for comprehensive information.
|
402 |
-
|
403 |
"""
|
404 |
-
|
405 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
406 |
with gr.Row(variant='panel',
|
407 |
equal_height=True,
|
408 |
show_progress=False):
|
409 |
-
with gr.Column(scale=1, min_width=100):
|
410 |
self.upload_btn = gr.Button(value=upload_sty +
|
411 |
' Upload',
|
412 |
variant='secondary')
|
@@ -416,12 +493,16 @@ class ChatBotUI(object):
|
|
416 |
label='Instruction',
|
417 |
container=False)
|
418 |
with gr.Column(scale=1, min_width=100):
|
419 |
-
self.chat_btn = gr.Button(value=
|
420 |
variant='primary')
|
421 |
with gr.Column(scale=1, min_width=100):
|
422 |
self.retry_btn = gr.Button(value=refresh_sty +
|
423 |
' Retry',
|
424 |
variant='secondary')
|
|
|
|
|
|
|
|
|
425 |
with gr.Column(scale=(1 if self.enable_i2v else 0),
|
426 |
min_width=0):
|
427 |
self.video_gen_btn = gr.Button(value=video_sty +
|
@@ -457,19 +538,77 @@ class ChatBotUI(object):
|
|
457 |
lock.acquire()
|
458 |
del self.pipe
|
459 |
torch.cuda.empty_cache()
|
460 |
-
model_cfg = Config(load=True,
|
461 |
-
cfg_file=self.model_choices[model_name])
|
462 |
self.pipe = ACEInference()
|
463 |
-
self.pipe.init_from_cfg(
|
464 |
self.model_name = model_name
|
465 |
lock.release()
|
466 |
|
467 |
-
return model_name, gr.update(), gr.update()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
468 |
|
469 |
self.model_name_dd.change(
|
470 |
change_model,
|
471 |
inputs=[self.model_name_dd],
|
472 |
-
outputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
473 |
|
474 |
########################################
|
475 |
def generate_gallery(text, images):
|
@@ -516,7 +655,6 @@ class ChatBotUI(object):
|
|
516 |
outputs=[self.text, self.gallery])
|
517 |
|
518 |
########################################
|
519 |
-
@spaces.GPU(duration=120)
|
520 |
def generate_video(message,
|
521 |
extend_prompt,
|
522 |
history,
|
@@ -527,6 +665,9 @@ class ChatBotUI(object):
|
|
527 |
fps,
|
528 |
seed,
|
529 |
progress=gr.Progress(track_tqdm=True)):
|
|
|
|
|
|
|
530 |
generator = torch.Generator(device='cuda').manual_seed(seed)
|
531 |
img_ids = re.findall('@(.*?)[ ,;.?$]', message)
|
532 |
if len(img_ids) == 0:
|
@@ -597,8 +738,12 @@ class ChatBotUI(object):
|
|
597 |
outputs=[self.history, self.chatbot, self.text, self.gallery])
|
598 |
|
599 |
########################################
|
600 |
-
@spaces.GPU(duration=
|
601 |
-
def run_chat(
|
|
|
|
|
|
|
|
|
602 |
extend_prompt,
|
603 |
history,
|
604 |
images,
|
@@ -607,6 +752,8 @@ class ChatBotUI(object):
|
|
607 |
negative_prompt,
|
608 |
cfg_scale,
|
609 |
rescale,
|
|
|
|
|
610 |
step,
|
611 |
seed,
|
612 |
output_h,
|
@@ -618,12 +765,25 @@ class ChatBotUI(object):
|
|
618 |
video_fps,
|
619 |
video_seed,
|
620 |
progress=gr.Progress(track_tqdm=True)):
|
|
|
|
|
|
|
|
|
|
|
|
|
621 |
retry_msg = message
|
622 |
gen_id = get_md5(message)[:12]
|
623 |
save_path = os.path.join(self.cache_dir, f'{gen_id}.png')
|
624 |
|
625 |
img_ids = re.findall('@(.*?)[ ,;.?$]', message)
|
626 |
history_io = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
627 |
new_message = message
|
628 |
|
629 |
if len(img_ids) > 0:
|
@@ -655,9 +815,9 @@ class ChatBotUI(object):
|
|
655 |
history_io = history_result[img_id]
|
656 |
|
657 |
buffered = io.BytesIO()
|
658 |
-
edit_image[0].save(buffered, format='
|
659 |
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
660 |
-
img_str = f'<img src="data:image/
|
661 |
pre_info = f'Received one or more images, so image editing is conducted.\n The first input image @{img_ids[0]} is:\n {img_str}'
|
662 |
else:
|
663 |
pre_info = 'No image ids were found in the provided text prompt, so text-guided image generation is conducted. \n'
|
@@ -682,6 +842,9 @@ class ChatBotUI(object):
|
|
682 |
guide_scale=cfg_scale,
|
683 |
guide_rescale=rescale,
|
684 |
seed=seed,
|
|
|
|
|
|
|
685 |
)
|
686 |
|
687 |
img = imgs[0]
|
@@ -728,9 +891,9 @@ class ChatBotUI(object):
|
|
728 |
}
|
729 |
|
730 |
buffered = io.BytesIO()
|
731 |
-
img.convert('RGB').save(buffered, format='
|
732 |
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
733 |
-
img_str = f'<img src="data:image/
|
734 |
|
735 |
history.append(
|
736 |
(message,
|
@@ -790,21 +953,25 @@ class ChatBotUI(object):
|
|
790 |
while len(history) >= self.max_msgs:
|
791 |
history.pop(0)
|
792 |
|
793 |
-
return history, images,
|
794 |
-
|
795 |
-
|
|
|
796 |
|
797 |
chat_inputs = [
|
|
|
798 |
self.extend_prompt, self.history, self.images, self.use_history,
|
799 |
self.history_result, self.negative_prompt, self.cfg_scale,
|
800 |
-
self.rescale, self.
|
|
|
801 |
self.output_width, self.video_auto, self.video_step,
|
802 |
self.video_frames, self.video_cfg_scale, self.video_fps,
|
803 |
self.video_seed
|
804 |
]
|
805 |
|
806 |
chat_outputs = [
|
807 |
-
self.history, self.images, self.
|
|
|
808 |
self.text, self.gallery, self.retry_msg
|
809 |
]
|
810 |
|
@@ -824,7 +991,7 @@ class ChatBotUI(object):
|
|
824 |
outputs=chat_outputs)
|
825 |
|
826 |
########################################
|
827 |
-
@spaces.GPU(duration=
|
828 |
def run_example(task, img, img_mask, ref1, prompt, seed):
|
829 |
edit_image, edit_image_mask, edit_task = [], [], []
|
830 |
if img is not None:
|
@@ -848,9 +1015,9 @@ class ChatBotUI(object):
|
|
848 |
edit_task.append('')
|
849 |
|
850 |
buffered = io.BytesIO()
|
851 |
-
img.save(buffered, format='
|
852 |
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
853 |
-
img_str = f'<img src="data:image/
|
854 |
pre_info = f'Received one or more images, so image editing is conducted.\n The first input image is:\n {img_str}'
|
855 |
else:
|
856 |
pre_info = 'No image ids were found in the provided text prompt, so text-guided image generation is conducted. \n'
|
@@ -866,13 +1033,15 @@ class ChatBotUI(object):
|
|
866 |
prompt=[prompt] * img_num,
|
867 |
negative_prompt=[''] * img_num,
|
868 |
seed=seed,
|
|
|
|
|
869 |
)
|
870 |
|
871 |
img = imgs[0]
|
872 |
buffered = io.BytesIO()
|
873 |
-
img.convert('RGB').save(buffered, format='
|
874 |
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
875 |
-
img_str = f'<img src="data:image/
|
876 |
history = [(prompt,
|
877 |
f'{pre_info} The generated image is:\n {img_str}')]
|
878 |
return self.get_history(history), gr.update(value=''), gr.update(
|
@@ -911,21 +1080,23 @@ class ChatBotUI(object):
|
|
911 |
return (gr.update(visible=True,
|
912 |
scale=1), gr.update(visible=True, scale=1),
|
913 |
gr.update(visible=True), gr.update(visible=False),
|
914 |
-
gr.update(visible=False), gr.update(visible=False)
|
|
|
915 |
|
916 |
self.upload_btn.click(upload_image,
|
917 |
inputs=[],
|
918 |
outputs=[
|
919 |
self.chat_page, self.editor_page,
|
920 |
self.upload_tab, self.edit_tab,
|
921 |
-
self.image_view_tab, self.video_view_tab
|
|
|
922 |
])
|
923 |
|
924 |
########################################
|
925 |
def edit_image(evt: gr.SelectData):
|
926 |
if isinstance(evt.value, str):
|
927 |
img_b64s = re.findall(
|
928 |
-
'<img src="data:image/
|
929 |
evt.value)
|
930 |
imgs = [
|
931 |
Image.open(io.BytesIO(base64.b64decode(copy.deepcopy(i))))
|
@@ -933,13 +1104,19 @@ class ChatBotUI(object):
|
|
933 |
]
|
934 |
if len(imgs) > 0:
|
935 |
if len(imgs) == 2:
|
936 |
-
|
|
|
|
|
|
|
937 |
edit_img = copy.deepcopy(imgs[-1])
|
938 |
else:
|
939 |
-
|
940 |
-
copy.deepcopy(imgs[-1])
|
941 |
-
|
942 |
-
|
|
|
|
|
|
|
943 |
edit_img = copy.deepcopy(imgs[-1])
|
944 |
|
945 |
return (gr.update(visible=True,
|
@@ -948,11 +1125,12 @@ class ChatBotUI(object):
|
|
948 |
gr.update(visible=False), gr.update(visible=True),
|
949 |
gr.update(visible=True), gr.update(visible=False),
|
950 |
gr.update(value=edit_img),
|
951 |
-
gr.update(value=view_img), gr.update(value=None)
|
|
|
952 |
else:
|
953 |
return (gr.update(), gr.update(), gr.update(), gr.update(),
|
954 |
gr.update(), gr.update(), gr.update(), gr.update(),
|
955 |
-
gr.update())
|
956 |
elif isinstance(evt.value, dict) and evt.value.get(
|
957 |
'component', '') == 'video':
|
958 |
value = evt.value['value']['video']['path']
|
@@ -960,11 +1138,12 @@ class ChatBotUI(object):
|
|
960 |
scale=1), gr.update(visible=True, scale=1),
|
961 |
gr.update(visible=False), gr.update(visible=False),
|
962 |
gr.update(visible=False), gr.update(visible=True),
|
963 |
-
gr.update(), gr.update(), gr.update(value=value)
|
|
|
964 |
else:
|
965 |
return (gr.update(), gr.update(), gr.update(), gr.update(),
|
966 |
gr.update(), gr.update(), gr.update(), gr.update(),
|
967 |
-
gr.update())
|
968 |
|
969 |
self.chatbot.select(edit_image,
|
970 |
outputs=[
|
@@ -972,16 +1151,17 @@ class ChatBotUI(object):
|
|
972 |
self.upload_tab, self.edit_tab,
|
973 |
self.image_view_tab, self.video_view_tab,
|
974 |
self.image_editor, self.image_viewer,
|
975 |
-
self.video_viewer
|
976 |
])
|
977 |
|
978 |
-
self.
|
979 |
-
|
980 |
-
|
|
|
981 |
|
982 |
########################################
|
983 |
def submit_upload_image(image, history, images):
|
984 |
-
history, images = self.add_uploaded_image_to_history(
|
985 |
image, history, images)
|
986 |
return gr.update(visible=False), gr.update(
|
987 |
visible=True), gr.update(
|
@@ -1151,14 +1331,14 @@ class ChatBotUI(object):
|
|
1151 |
thumbnail.save(thumbnail_path, format='JPEG')
|
1152 |
|
1153 |
buffered = io.BytesIO()
|
1154 |
-
img.convert('RGB').save(buffered, format='
|
1155 |
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
1156 |
-
img_str = f'<img src="data:image/
|
1157 |
|
1158 |
buffered = io.BytesIO()
|
1159 |
-
mask.convert('RGB').save(buffered, format='
|
1160 |
mask_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
1161 |
-
mask_str = f'<img src="data:image/
|
1162 |
|
1163 |
images[img_id] = {
|
1164 |
'image': save_path,
|
@@ -1207,22 +1387,21 @@ class ChatBotUI(object):
|
|
1207 |
}
|
1208 |
|
1209 |
buffered = io.BytesIO()
|
1210 |
-
img.convert('RGB').save(buffered, format='
|
1211 |
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
1212 |
-
img_str = f'<img src="data:image/
|
1213 |
|
1214 |
history.append(
|
1215 |
(None,
|
1216 |
f'This is uploaded image:\n {img_str} image ID is: {img_id}'))
|
1217 |
-
return history, images
|
1218 |
|
1219 |
|
1220 |
|
1221 |
if __name__ == '__main__':
|
1222 |
-
|
1223 |
-
cfg = Config(cfg_file="config/chatbot_ui.yaml")
|
1224 |
with gr.Blocks() as demo:
|
1225 |
chatbot = ChatBotUI(cfg)
|
1226 |
chatbot.create_ui()
|
1227 |
chatbot.set_callbacks()
|
1228 |
-
demo.launch()
|
|
|
1 |
# -*- coding: utf-8 -*-
|
2 |
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
import base64
|
4 |
import copy
|
5 |
import glob
|
6 |
import io
|
7 |
+
import os, csv, sys
|
8 |
import random
|
9 |
import re
|
10 |
+
import shlex
|
11 |
import string
|
12 |
+
import subprocess
|
13 |
import threading
|
14 |
import spaces
|
15 |
+
|
16 |
+
# Install flash-attn at startup, before the heavy ML imports below.
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE is added to the child environment
# (presumably so the install does not compile the CUDA extension -- confirm
# against the flash-attn setup script).
# pip is invoked through the running interpreter so the package is installed
# into the environment this process imports from, not whichever `pip` happens
# to be first on PATH.
_pip_install_cmd = [
    sys.executable, '-m', 'pip', 'install', 'flash-attn',
    '--no-build-isolation'
]
if subprocess.run(
        _pip_install_cmd,
        env=os.environ | {'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
).returncode != 0:
    # Best effort: keep starting up, but make the failure visible in the logs
    # instead of silently ignoring the exit status.
    print('Warning: `pip install flash-attn` exited with a non-zero status; '
          'continuing startup.')
|
18 |
+
|
19 |
import cv2
|
20 |
import gradio as gr
|
21 |
import numpy as np
|
22 |
import torch
|
23 |
import transformers
|
|
|
|
|
|
|
24 |
from PIL import Image
|
25 |
from transformers import AutoModel, AutoTokenizer
|
26 |
|
27 |
+
from scepter.modules.inference.ace_inference import ACEInference
|
28 |
from scepter.modules.utils.config import Config
|
29 |
from scepter.modules.utils.directory import get_md5
|
30 |
from scepter.modules.utils.file_system import FS
|
31 |
from scepter.studio.utils.env import init_env
|
32 |
+
from importlib.metadata import version
|
33 |
|
|
|
34 |
from example import get_examples
|
35 |
from utils import load_image
|
36 |
|
37 |
+
# Raise the csv parser's maximum field size as high as the platform allows.
# sys.maxsize is tried first; where the underlying C long is narrower than
# Py_ssize_t (e.g. 64-bit Windows) that call raises OverflowError, so the
# limit is halved until the csv module accepts it.
_csv_field_limit = sys.maxsize
while True:
    try:
        csv.field_size_limit(_csv_field_limit)
        break
    except OverflowError:
        _csv_field_limit //= 2
|
38 |
|
39 |
refresh_sty = '\U0001f504' # 🔄
|
40 |
clear_sty = '\U0001f5d1' # 🗑️
|
|
|
48 |
|
49 |
class ChatBotUI(object):
|
50 |
def __init__(self,
|
51 |
+
cfg_general_file,
|
52 |
+
is_debug=False,
|
53 |
+
language='en',
|
54 |
root_work_dir='./'):
|
55 |
+
try:
|
56 |
+
from diffusers import CogVideoXImageToVideoPipeline
|
57 |
+
from diffusers.utils import export_to_video
|
58 |
+
except Exception as e:
|
59 |
+
print(f"Import diffusers failed, please install or upgrade diffusers. Error information: {e}")
|
60 |
+
|
61 |
+
cfg = Config(cfg_file=cfg_general_file)
|
62 |
+
if cfg.have("FILE_SYSTEM"):
|
63 |
+
for file_sys in cfg.FILE_SYSTEM:
|
64 |
+
fs_prefix = FS.init_fs_client(file_sys)
|
65 |
+
else:
|
66 |
+
fs_prefix = FS.init_fs_client(cfg)
|
67 |
cfg.WORK_DIR = os.path.join(root_work_dir, cfg.WORK_DIR)
|
68 |
if not FS.exists(cfg.WORK_DIR):
|
69 |
FS.make_dir(cfg.WORK_DIR)
|
70 |
cfg = init_env(cfg)
|
71 |
self.cache_dir = cfg.WORK_DIR
|
72 |
+
self.chatbot_examples = get_examples(self.cache_dir) if not cfg.get('SKIP_EXAMPLES', False) else []
|
73 |
self.model_cfg_dir = cfg.MODEL.EDIT_MODEL.MODEL_CFG_DIR
|
74 |
self.model_yamls = glob.glob(os.path.join(self.model_cfg_dir,
|
75 |
'*.yaml'))
|
76 |
self.model_choices = dict()
|
77 |
+
self.default_model_name = ''
|
78 |
for i in self.model_yamls:
|
79 |
+
model_cfg = Config(load=True, cfg_file=i)
|
80 |
+
model_name = model_cfg.NAME
|
81 |
+
if model_cfg.IS_DEFAULT: self.default_model_name = model_name
|
82 |
+
self.model_choices[model_name] = model_cfg
|
83 |
+
print('Models: ', self.model_choices.keys())
|
84 |
+
|
85 |
+
#FS.get_from("ms://AI-ModelScope/FLUX.1-dev@flux1-dev.safetensors")
|
86 |
+
#FS.get_from("ms://AI-ModelScope/FLUX.1-dev@ae.safetensors")
|
87 |
+
#FS.get_dir_to_local_dir("ms://AI-ModelScope/FLUX.1-dev@text_encoder_2/")
|
88 |
+
#FS.get_dir_to_local_dir("ms://AI-ModelScope/FLUX.1-dev@tokenizer_2/")
|
89 |
+
#FS.get_dir_to_local_dir("ms://AI-ModelScope/FLUX.1-dev@text_encoder/")
|
90 |
+
#FS.get_dir_to_local_dir("ms://AI-ModelScope/FLUX.1-dev@tokenizer/")
|
91 |
+
|
92 |
+
assert len(self.model_choices) > 0
|
93 |
+
if self.default_model_name == "": self.default_model_name = self.model_choices.keys()[0]
|
94 |
+
self.model_name = self.default_model_name
|
95 |
self.pipe = ACEInference()
|
96 |
+
self.pipe.init_from_cfg(self.model_choices[self.default_model_name])
|
97 |
self.max_msgs = 20
|
|
|
98 |
self.enable_i2v = cfg.get('ENABLE_I2V', False)
|
99 |
+
self.gradio_version = version('gradio')
|
100 |
+
|
101 |
if self.enable_i2v:
|
102 |
self.i2v_model_dir = cfg.MODEL.I2V.MODEL_DIR
|
103 |
self.i2v_model_name = cfg.MODEL.I2V.MODEL_NAME
|
|
|
133 |
)
|
134 |
|
135 |
sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
|
|
|
136 |
For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
|
137 |
There are a few rules to follow:
|
|
|
138 |
You will only ever output a single video description per user request.
|
|
|
139 |
When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
|
140 |
Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
|
|
|
141 |
Video descriptions must have the same num of words as examples below. Extra words will be ignored.
|
142 |
"""
|
143 |
self.enhance_ctx = [
|
|
|
184 |
]
|
185 |
|
186 |
def create_ui(self):
|
187 |
+
|
188 |
css = '.chatbot.prose.md {opacity: 1.0 !important} #chatbot {opacity: 1.0 !important}'
|
189 |
with gr.Blocks(css=css,
|
190 |
title='Chatbot',
|
|
|
195 |
self.history_result = gr.State(value={})
|
196 |
self.retry_msg = gr.State(value='')
|
197 |
with gr.Group():
|
198 |
+
self.ui_mode = gr.State(value='legacy')
|
199 |
+
with gr.Row(equal_height=True, visible=False) as self.chat_group:
|
200 |
with gr.Column(visible=True) as self.chat_page:
|
201 |
self.chatbot = gr.Chatbot(
|
202 |
height=600,
|
|
|
211 |
size='sm')
|
212 |
|
213 |
with gr.Column(visible=False) as self.editor_page:
|
214 |
+
with gr.Tabs(visible=False) as self.upload_tabs:
|
215 |
with gr.Tab(id='ImageUploader',
|
216 |
label='Image Uploader',
|
217 |
visible=True) as self.upload_tab:
|
|
|
220 |
interactive=True,
|
221 |
type='pil',
|
222 |
image_mode='RGB',
|
223 |
+
sources=['upload'],
|
224 |
elem_id='image_uploader',
|
225 |
format='png')
|
226 |
with gr.Row():
|
|
|
228 |
value='Submit',
|
229 |
elem_id='upload_submit')
|
230 |
self.ext_btn_1 = gr.Button(value='Exit')
|
231 |
+
with gr.Tabs(visible=False) as self.edit_tabs:
|
232 |
with gr.Tab(id='ImageEditor',
|
233 |
+
label='Image Editor') as self.edit_tab:
|
|
|
234 |
self.mask_type = gr.Dropdown(
|
235 |
label='Mask Type',
|
236 |
choices=[
|
|
|
293 |
self.ext_btn_2 = gr.Button(value='Exit')
|
294 |
|
295 |
with gr.Tab(id='ImageViewer',
|
296 |
+
label='Image Viewer') as self.image_view_tab:
|
297 |
+
if self.gradio_version >= '5.0.0':
|
298 |
+
self.image_viewer = gr.Image(
|
299 |
+
label='Image',
|
300 |
+
type='pil',
|
301 |
+
show_download_button=True,
|
302 |
+
elem_id='image_viewer')
|
303 |
+
else:
|
304 |
+
try:
|
305 |
+
from gradio_imageslider import ImageSlider
|
306 |
+
except Exception as e:
|
307 |
+
print(f"Import gradio_imageslider failed, please install.")
|
308 |
+
self.image_viewer = ImageSlider(
|
309 |
+
label='Image',
|
310 |
+
type='pil',
|
311 |
+
show_download_button=True,
|
312 |
+
elem_id='image_viewer')
|
313 |
|
314 |
self.ext_btn_3 = gr.Button(value='Exit')
|
315 |
|
|
|
328 |
|
329 |
self.ext_btn_4 = gr.Button(value='Exit')
|
330 |
|
331 |
+
with gr.Row(equal_height=True, visible=True) as self.legacy_group:
|
332 |
+
with gr.Column():
|
333 |
+
self.legacy_image_uploader = gr.Image(
|
334 |
+
height=550,
|
335 |
+
interactive=True,
|
336 |
+
type='pil',
|
337 |
+
image_mode='RGB',
|
338 |
+
elem_id='legacy_image_uploader',
|
339 |
+
format='png')
|
340 |
+
with gr.Column():
|
341 |
+
self.legacy_image_viewer = gr.Image(
|
342 |
+
label='Image',
|
343 |
+
height=550,
|
344 |
+
type='pil',
|
345 |
+
interactive=False,
|
346 |
+
show_download_button=True,
|
347 |
+
elem_id='image_viewer')
|
348 |
+
|
349 |
+
|
350 |
with gr.Accordion(label='Setting', open=False):
|
351 |
with gr.Row():
|
352 |
self.model_name_dd = gr.Dropdown(
|
353 |
choices=self.model_choices,
|
354 |
+
value=self.default_model_name,
|
355 |
label='Model Version')
|
356 |
|
357 |
with gr.Row():
|
|
|
362 |
label='Negative Prompt',
|
363 |
container=False)
|
364 |
|
365 |
+
with gr.Row():
|
366 |
+
# REFINER_PROMPT
|
367 |
+
self.refiner_prompt = gr.Textbox(
|
368 |
+
value=self.pipe.input.get("refiner_prompt", ""),
|
369 |
+
visible=self.pipe.input.get("refiner_prompt", None) is not None,
|
370 |
+
placeholder=
|
371 |
+
'Prompt used for refiner',
|
372 |
+
label='Refiner Prompt',
|
373 |
+
container=False)
|
374 |
+
|
375 |
+
|
376 |
with gr.Row():
|
377 |
with gr.Column(scale=8, min_width=500):
|
378 |
with gr.Row():
|
379 |
self.step = gr.Slider(minimum=1,
|
380 |
maximum=1000,
|
381 |
+
value=self.pipe.input.get("sample_steps", 20),
|
382 |
+
visible=self.pipe.input.get("sample_steps", None) is not None,
|
383 |
label='Sample Step')
|
384 |
self.cfg_scale = gr.Slider(
|
385 |
minimum=1.0,
|
386 |
maximum=20.0,
|
387 |
+
value=self.pipe.input.get("guide_scale", 4.5),
|
388 |
+
visible=self.pipe.input.get("guide_scale", None) is not None,
|
389 |
label='Guidance Scale')
|
390 |
self.rescale = gr.Slider(minimum=0.0,
|
391 |
maximum=1.0,
|
392 |
+
value=self.pipe.input.get("guide_rescale", 0.5),
|
393 |
+
visible=self.pipe.input.get("guide_rescale", None) is not None,
|
394 |
label='Rescale')
|
395 |
+
self.refiner_scale = gr.Slider(minimum=-0.1,
|
396 |
+
maximum=1.0,
|
397 |
+
value=self.pipe.input.get("refiner_scale", 0.5),
|
398 |
+
visible=self.pipe.input.get("refiner_scale", None) is not None,
|
399 |
+
label='Refiner Scale')
|
400 |
self.seed = gr.Slider(minimum=-1,
|
401 |
maximum=10000000,
|
402 |
value=-1,
|
403 |
label='Seed')
|
404 |
self.output_height = gr.Slider(
|
405 |
minimum=256,
|
406 |
+
maximum=1440,
|
407 |
+
value=self.pipe.input.get("output_height", 1024),
|
408 |
+
visible=self.pipe.input.get("output_height", None) is not None,
|
409 |
label='Output Height')
|
410 |
self.output_width = gr.Slider(
|
411 |
minimum=256,
|
412 |
+
maximum=1440,
|
413 |
+
value=self.pipe.input.get("output_width", 1024),
|
414 |
+
visible=self.pipe.input.get("output_width", None) is not None,
|
415 |
label='Output Width')
|
416 |
with gr.Column(scale=1, min_width=50):
|
417 |
self.use_history = gr.Checkbox(value=False,
|
418 |
label='Use History')
|
419 |
+
self.use_ace = gr.Checkbox(value=self.pipe.input.get("use_ace", True),
|
420 |
+
visible=self.pipe.input.get("use_ace", None) is not None,
|
421 |
+
label='Use ACE')
|
422 |
self.video_auto = gr.Checkbox(
|
423 |
value=False,
|
424 |
label='Auto Gen Video',
|
|
|
455 |
visible=True)
|
456 |
|
457 |
with gr.Row():
|
458 |
+
self.chatbot_inst = """
|
459 |
**Instruction**:
|
|
|
460 |
1. Click 'Upload' button to upload one or more images as input images.
|
461 |
2. Enter '@' in the text box will exhibit all images in the gallery.
|
462 |
3. Select the image you wish to edit from the gallery, and its Image ID will be displayed in the text box.
|
|
|
466 |
6. **Important** To render text on an image, please ensure to include a space between each letter. For instance, "add text 'g i r l' on the mask area of @xxxxx".
|
467 |
7. To implement local editing based on a specified mask, simply click on the image within the chat window to access the image editor. Here, you can draw a mask and then click the 'Submit' button to upload the edited image along with the mask. For inpainting tasks, select the 'Composite' mask type, while for outpainting tasks, choose the 'Outpainting' mask type. For all other local editing tasks, please select the 'Background' mask type.
|
468 |
8. If you find our work valuable, we invite you to refer to the [ACE Page](https://ali-vilab.github.io/ace-page/) for comprehensive information.
|
|
|
469 |
"""
|
470 |
+
|
471 |
+
self.legacy_inst = """
|
472 |
+
**Instruction**:
|
473 |
+
1. You can edit the image by uploading it; if no image is uploaded, an image will be generated from text..
|
474 |
+
2. Enter '@' in the text box will exhibit all images in the gallery.
|
475 |
+
3. Select the image you wish to edit from the gallery, and its Image ID will be displayed in the text box.
|
476 |
+
4. **Important** To render text on an image, please ensure to include a space between each letter. For instance, "add text 'g i r l' on the mask area of @xxxxx".
|
477 |
+
5. To perform multi-step editing, partial editing, inpainting, outpainting, and other operations, please click the Chatbot Checkbox to enable the conversational editing mode and follow the relevant instructions..
|
478 |
+
6. If you find our work valuable, we invite you to refer to the [ACE Page](https://ali-vilab.github.io/ace-page/) for comprehensive information.
|
479 |
+
"""
|
480 |
+
|
481 |
+
self.instruction = gr.Markdown(value=self.legacy_inst)
|
482 |
+
|
483 |
with gr.Row(variant='panel',
|
484 |
equal_height=True,
|
485 |
show_progress=False):
|
486 |
+
with gr.Column(scale=1, min_width=100, visible=False) as self.upload_panel:
|
487 |
self.upload_btn = gr.Button(value=upload_sty +
|
488 |
' Upload',
|
489 |
variant='secondary')
|
|
|
493 |
label='Instruction',
|
494 |
container=False)
|
495 |
with gr.Column(scale=1, min_width=100):
|
496 |
+
self.chat_btn = gr.Button(value='Generate',
|
497 |
variant='primary')
|
498 |
with gr.Column(scale=1, min_width=100):
|
499 |
self.retry_btn = gr.Button(value=refresh_sty +
|
500 |
' Retry',
|
501 |
variant='secondary')
|
502 |
+
with gr.Column(scale=1, min_width=100):
|
503 |
+
self.mode_checkbox = gr.Checkbox(
|
504 |
+
value=False,
|
505 |
+
label='ChatBot')
|
506 |
with gr.Column(scale=(1 if self.enable_i2v else 0),
|
507 |
min_width=0):
|
508 |
self.video_gen_btn = gr.Button(value=video_sty +
|
|
|
538 |
lock.acquire()
|
539 |
del self.pipe
|
540 |
torch.cuda.empty_cache()
|
|
|
|
|
541 |
self.pipe = ACEInference()
|
542 |
+
self.pipe.init_from_cfg(self.model_choices[model_name])
|
543 |
self.model_name = model_name
|
544 |
lock.release()
|
545 |
|
546 |
+
return (model_name, gr.update(), gr.update(),
|
547 |
+
gr.Slider(
|
548 |
+
value=self.pipe.input.get("sample_steps", 20),
|
549 |
+
visible=self.pipe.input.get("sample_steps", None) is not None),
|
550 |
+
gr.Slider(
|
551 |
+
value=self.pipe.input.get("guide_scale", 4.5),
|
552 |
+
visible=self.pipe.input.get("guide_scale", None) is not None),
|
553 |
+
gr.Slider(
|
554 |
+
value=self.pipe.input.get("guide_rescale", 0.5),
|
555 |
+
visible=self.pipe.input.get("guide_rescale", None) is not None),
|
556 |
+
gr.Slider(
|
557 |
+
value=self.pipe.input.get("output_height", 1024),
|
558 |
+
visible=self.pipe.input.get("output_height", None) is not None),
|
559 |
+
gr.Slider(
|
560 |
+
value=self.pipe.input.get("output_width", 1024),
|
561 |
+
visible=self.pipe.input.get("output_width", None) is not None),
|
562 |
+
gr.Textbox(
|
563 |
+
value=self.pipe.input.get("refiner_prompt", ""),
|
564 |
+
visible=self.pipe.input.get("refiner_prompt", None) is not None),
|
565 |
+
gr.Slider(
|
566 |
+
value=self.pipe.input.get("refiner_scale", 0.5),
|
567 |
+
visible=self.pipe.input.get("refiner_scale", None) is not None
|
568 |
+
),
|
569 |
+
gr.Checkbox(
|
570 |
+
value=self.pipe.input.get("use_ace", True),
|
571 |
+
visible=self.pipe.input.get("use_ace", None) is not None
|
572 |
+
)
|
573 |
+
)
|
574 |
|
575 |
self.model_name_dd.change(
|
576 |
change_model,
|
577 |
inputs=[self.model_name_dd],
|
578 |
+
outputs=[
|
579 |
+
self.model_name_dd, self.chatbot, self.text,
|
580 |
+
self.step,
|
581 |
+
self.cfg_scale, self.rescale, self.output_height,
|
582 |
+
self.output_width, self.refiner_prompt, self.refiner_scale,
|
583 |
+
self.use_ace])
|
584 |
+
|
585 |
+
|
586 |
+
def mode_change(mode_check):
    """Build the component updates for switching between the two UI modes.

    Args:
        mode_check: State of the mode checkbox. Truthy selects the ChatBot
            layout, falsy selects the Legacy layout.

    Returns:
        A 6-tuple of Gradio components, in the order of the outputs wired
        below: legacy row, chat row, chat button, UI-mode state, upload
        panel column and the instruction markdown.
    """
    chat_mode = bool(mode_check)
    if chat_mode:
        # ChatBot
        btn_label = 'Generate'
        mode_tag = 'chatbot'
        help_md = self.chatbot_inst
    else:
        # Legacy
        btn_label = chat_sty + ' Chat'
        mode_tag = 'legacy'
        help_md = self.legacy_inst

    # The legacy row and the chat row are mutually exclusive; the upload
    # panel column is only shown together with the chat row.
    return (
        gr.Row(visible=not chat_mode),
        gr.Row(visible=chat_mode),
        gr.Button(value=btn_label),
        gr.State(value=mode_tag),
        gr.Column(visible=chat_mode),
        gr.Markdown(value=help_md),
    )

# Order must match the tuple returned by mode_change.
mode_outputs = [
    self.legacy_group,
    self.chat_group,
    self.chat_btn,
    self.ui_mode,
    self.upload_panel,
    self.instruction,
]
self.mode_checkbox.change(mode_change,
                          inputs=[self.mode_checkbox],
                          outputs=mode_outputs)
|
611 |
+
|
612 |
|
613 |
########################################
|
614 |
def generate_gallery(text, images):
|
|
|
655 |
outputs=[self.text, self.gallery])
|
656 |
|
657 |
########################################
|
|
|
658 |
def generate_video(message,
|
659 |
extend_prompt,
|
660 |
history,
|
|
|
665 |
fps,
|
666 |
seed,
|
667 |
progress=gr.Progress(track_tqdm=True)):
|
668 |
+
|
669 |
+
from diffusers.utils import export_to_video
|
670 |
+
|
671 |
generator = torch.Generator(device='cuda').manual_seed(seed)
|
672 |
img_ids = re.findall('@(.*?)[ ,;.?$]', message)
|
673 |
if len(img_ids) == 0:
|
|
|
738 |
outputs=[self.history, self.chatbot, self.text, self.gallery])
|
739 |
|
740 |
########################################
|
741 |
+
@spaces.GPU(duration=240)
|
742 |
+
def run_chat(
|
743 |
+
message,
|
744 |
+
legacy_image,
|
745 |
+
ui_mode,
|
746 |
+
use_ace,
|
747 |
extend_prompt,
|
748 |
history,
|
749 |
images,
|
|
|
752 |
negative_prompt,
|
753 |
cfg_scale,
|
754 |
rescale,
|
755 |
+
refiner_prompt,
|
756 |
+
refiner_scale,
|
757 |
step,
|
758 |
seed,
|
759 |
output_h,
|
|
|
765 |
video_fps,
|
766 |
video_seed,
|
767 |
progress=gr.Progress(track_tqdm=True)):
|
768 |
+
legacy_img_ids = []
|
769 |
+
if ui_mode == 'legacy':
|
770 |
+
if legacy_image is not None:
|
771 |
+
history, images, img_id = self.add_uploaded_image_to_history(
|
772 |
+
legacy_image, history, images)
|
773 |
+
legacy_img_ids.append(img_id)
|
774 |
retry_msg = message
|
775 |
gen_id = get_md5(message)[:12]
|
776 |
save_path = os.path.join(self.cache_dir, f'{gen_id}.png')
|
777 |
|
778 |
img_ids = re.findall('@(.*?)[ ,;.?$]', message)
|
779 |
history_io = None
|
780 |
+
|
781 |
+
if len(img_ids) < 1:
|
782 |
+
img_ids = legacy_img_ids
|
783 |
+
for img_id in img_ids:
|
784 |
+
if f'@{img_id}' not in message:
|
785 |
+
message = f'@{img_id} ' + message
|
786 |
+
|
787 |
new_message = message
|
788 |
|
789 |
if len(img_ids) > 0:
|
|
|
815 |
history_io = history_result[img_id]
|
816 |
|
817 |
buffered = io.BytesIO()
|
818 |
+
edit_image[0].save(buffered, format='PNG')
|
819 |
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
820 |
+
img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
|
821 |
pre_info = f'Received one or more images, so image editing is conducted.\n The first input image @{img_ids[0]} is:\n {img_str}'
|
822 |
else:
|
823 |
pre_info = 'No image ids were found in the provided text prompt, so text-guided image generation is conducted. \n'
|
|
|
842 |
guide_scale=cfg_scale,
|
843 |
guide_rescale=rescale,
|
844 |
seed=seed,
|
845 |
+
refiner_prompt=refiner_prompt,
|
846 |
+
refiner_scale=refiner_scale,
|
847 |
+
use_ace=use_ace
|
848 |
)
|
849 |
|
850 |
img = imgs[0]
|
|
|
891 |
}
|
892 |
|
893 |
buffered = io.BytesIO()
|
894 |
+
img.convert('RGB').save(buffered, format='PNG')
|
895 |
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
896 |
+
img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
|
897 |
|
898 |
history.append(
|
899 |
(message,
|
|
|
953 |
while len(history) >= self.max_msgs:
|
954 |
history.pop(0)
|
955 |
|
956 |
+
return (history, images, gr.Image(value=save_path),
|
957 |
+
history_result, self.get_history(
|
958 |
+
history), gr.update(), gr.update(
|
959 |
+
visible=False), retry_msg)
|
960 |
|
961 |
chat_inputs = [
|
962 |
+
self.legacy_image_uploader, self.ui_mode, self.use_ace,
|
963 |
self.extend_prompt, self.history, self.images, self.use_history,
|
964 |
self.history_result, self.negative_prompt, self.cfg_scale,
|
965 |
+
self.rescale, self.refiner_prompt, self.refiner_scale,
|
966 |
+
self.step, self.seed, self.output_height,
|
967 |
self.output_width, self.video_auto, self.video_step,
|
968 |
self.video_frames, self.video_cfg_scale, self.video_fps,
|
969 |
self.video_seed
|
970 |
]
|
971 |
|
972 |
chat_outputs = [
|
973 |
+
self.history, self.images, self.legacy_image_viewer,
|
974 |
+
self.history_result, self.chatbot,
|
975 |
self.text, self.gallery, self.retry_msg
|
976 |
]
|
977 |
|
|
|
991 |
outputs=chat_outputs)
|
992 |
|
993 |
########################################
|
994 |
+
@spaces.GPU(duration=120)
|
995 |
def run_example(task, img, img_mask, ref1, prompt, seed):
|
996 |
edit_image, edit_image_mask, edit_task = [], [], []
|
997 |
if img is not None:
|
|
|
1015 |
edit_task.append('')
|
1016 |
|
1017 |
buffered = io.BytesIO()
|
1018 |
+
img.save(buffered, format='PNG')
|
1019 |
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
1020 |
+
img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
|
1021 |
pre_info = f'Received one or more images, so image editing is conducted.\n The first input image is:\n {img_str}'
|
1022 |
else:
|
1023 |
pre_info = 'No image ids were found in the provided text prompt, so text-guided image generation is conducted. \n'
|
|
|
1033 |
prompt=[prompt] * img_num,
|
1034 |
negative_prompt=[''] * img_num,
|
1035 |
seed=seed,
|
1036 |
+
refiner_prompt=self.pipe.input.get("refiner_prompt", ""),
|
1037 |
+
refiner_scale=self.pipe.input.get("refiner_scale", 0.0),
|
1038 |
)
|
1039 |
|
1040 |
img = imgs[0]
|
1041 |
buffered = io.BytesIO()
|
1042 |
+
img.convert('RGB').save(buffered, format='PNG')
|
1043 |
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
1044 |
+
img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
|
1045 |
history = [(prompt,
|
1046 |
f'{pre_info} The generated image is:\n {img_str}')]
|
1047 |
return self.get_history(history), gr.update(value=''), gr.update(
|
|
|
1080 |
return (gr.update(visible=True,
|
1081 |
scale=1), gr.update(visible=True, scale=1),
|
1082 |
gr.update(visible=True), gr.update(visible=False),
|
1083 |
+
gr.update(visible=False), gr.update(visible=False),
|
1084 |
+
gr.update(visible=True))
|
1085 |
|
1086 |
self.upload_btn.click(upload_image,
|
1087 |
inputs=[],
|
1088 |
outputs=[
|
1089 |
self.chat_page, self.editor_page,
|
1090 |
self.upload_tab, self.edit_tab,
|
1091 |
+
self.image_view_tab, self.video_view_tab,
|
1092 |
+
self.upload_tabs
|
1093 |
])
|
1094 |
|
1095 |
########################################
|
1096 |
def edit_image(evt: gr.SelectData):
|
1097 |
if isinstance(evt.value, str):
|
1098 |
img_b64s = re.findall(
|
1099 |
+
'<img src="data:image/png;base64,(.*?)" style="pointer-events: none;">',
|
1100 |
evt.value)
|
1101 |
imgs = [
|
1102 |
Image.open(io.BytesIO(base64.b64decode(copy.deepcopy(i))))
|
|
|
1104 |
]
|
1105 |
if len(imgs) > 0:
|
1106 |
if len(imgs) == 2:
|
1107 |
+
if self.gradio_version >= '5.0.0':
|
1108 |
+
view_img = copy.deepcopy(imgs[-1])
|
1109 |
+
else:
|
1110 |
+
view_img = copy.deepcopy(imgs)
|
1111 |
edit_img = copy.deepcopy(imgs[-1])
|
1112 |
else:
|
1113 |
+
if self.gradio_version >= '5.0.0':
|
1114 |
+
view_img = copy.deepcopy(imgs[-1])
|
1115 |
+
else:
|
1116 |
+
view_img = [
|
1117 |
+
copy.deepcopy(imgs[-1]),
|
1118 |
+
copy.deepcopy(imgs[-1])
|
1119 |
+
]
|
1120 |
edit_img = copy.deepcopy(imgs[-1])
|
1121 |
|
1122 |
return (gr.update(visible=True,
|
|
|
1125 |
gr.update(visible=False), gr.update(visible=True),
|
1126 |
gr.update(visible=True), gr.update(visible=False),
|
1127 |
gr.update(value=edit_img),
|
1128 |
+
gr.update(value=view_img), gr.update(value=None),
|
1129 |
+
gr.update(visible=True))
|
1130 |
else:
|
1131 |
return (gr.update(), gr.update(), gr.update(), gr.update(),
|
1132 |
gr.update(), gr.update(), gr.update(), gr.update(),
|
1133 |
+
gr.update(), gr.update())
|
1134 |
elif isinstance(evt.value, dict) and evt.value.get(
|
1135 |
'component', '') == 'video':
|
1136 |
value = evt.value['value']['video']['path']
|
|
|
1138 |
scale=1), gr.update(visible=True, scale=1),
|
1139 |
gr.update(visible=False), gr.update(visible=False),
|
1140 |
gr.update(visible=False), gr.update(visible=True),
|
1141 |
+
gr.update(), gr.update(), gr.update(value=value),
|
1142 |
+
gr.update())
|
1143 |
else:
|
1144 |
return (gr.update(), gr.update(), gr.update(), gr.update(),
|
1145 |
gr.update(), gr.update(), gr.update(), gr.update(),
|
1146 |
+
gr.update(), gr.update())
|
1147 |
|
1148 |
self.chatbot.select(edit_image,
|
1149 |
outputs=[
|
|
|
1151 |
self.upload_tab, self.edit_tab,
|
1152 |
self.image_view_tab, self.video_view_tab,
|
1153 |
self.image_editor, self.image_viewer,
|
1154 |
+
self.video_viewer, self.edit_tabs
|
1155 |
])
|
1156 |
|
1157 |
+
# Decide whether the installed Gradio predates 5.x by comparing the major
# version numerically. A plain string comparison orders '10.0.0' before
# '5.0.0' and treats '5.0' as older than '5.0.0'.
gradio_major = re.match(r'\s*(\d+)', str(self.gradio_version))
if gradio_major is not None:
    gradio_is_pre_v5 = int(gradio_major.group(1)) < 5
else:
    # Unparseable version string: keep the previous lexical comparison.
    gradio_is_pre_v5 = self.gradio_version < '5.0.0'

if gradio_is_pre_v5:
    # Identity handler that feeds the viewer's value back into itself.
    # NOTE(review): presumably a refresh workaround for the pre-5 image
    # viewer component - confirm before removing.
    self.image_viewer.change(lambda x: x,
                             inputs=self.image_viewer,
                             outputs=self.image_viewer)
|
1161 |
|
1162 |
########################################
|
1163 |
def submit_upload_image(image, history, images):
|
1164 |
+
history, images, _ = self.add_uploaded_image_to_history(
|
1165 |
image, history, images)
|
1166 |
return gr.update(visible=False), gr.update(
|
1167 |
visible=True), gr.update(
|
|
|
1331 |
thumbnail.save(thumbnail_path, format='JPEG')
|
1332 |
|
1333 |
buffered = io.BytesIO()
|
1334 |
+
img.convert('RGB').save(buffered, format='PNG')
|
1335 |
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
1336 |
+
img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
|
1337 |
|
1338 |
buffered = io.BytesIO()
|
1339 |
+
mask.convert('RGB').save(buffered, format='PNG')
|
1340 |
mask_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
1341 |
+
mask_str = f'<img src="data:image/png;base64,{mask_b64}" style="pointer-events: none;">'
|
1342 |
|
1343 |
images[img_id] = {
|
1344 |
'image': save_path,
|
|
|
1387 |
}
|
1388 |
|
1389 |
buffered = io.BytesIO()
|
1390 |
+
img.convert('RGB').save(buffered, format='PNG')
|
1391 |
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
1392 |
+
img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
|
1393 |
|
1394 |
history.append(
|
1395 |
(None,
|
1396 |
f'This is uploaded image:\n {img_str} image ID is: {img_id}'))
|
1397 |
+
return history, images, img_id
|
1398 |
|
1399 |
|
1400 |
|
1401 |
if __name__ == '__main__':
|
1402 |
+
cfg = "config/chatbot_ui.yaml"
|
|
|
1403 |
with gr.Blocks() as demo:
|
1404 |
chatbot = ChatBotUI(cfg)
|
1405 |
chatbot.create_ui()
|
1406 |
chatbot.set_callbacks()
|
1407 |
+
demo.launch()
|