Spaces:
Runtime error
Runtime error
ryanzhangfan
commited on
Commit
•
9aa6aea
1
Parent(s):
f8f41f8
initial commit
Browse files- app.py +66 -0
- demo/__init__.py +0 -0
- demo/__pycache__/__init__.cpython-310.pyc +0 -0
- demo/__pycache__/chat_frontend.cpython-310.pyc +0 -0
- demo/__pycache__/generation_frontend.cpython-310.pyc +0 -0
- demo/__pycache__/meta.cpython-310.pyc +0 -0
- demo/__pycache__/utils.cpython-310.pyc +0 -0
- demo/chat_frontend.py +249 -0
- demo/generation_frontend.py +247 -0
- demo/meta.py +298 -0
- demo/utils.py +87 -0
- requirements.txt +5 -0
app.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
# ===================================================
|
4 |
+
#
|
5 |
+
# Author : Fan Zhang
|
6 |
+
# Email : zhangfan@baai.ac.cn
|
7 |
+
# Institute : Beijing Academy of Artificial Intelligence (BAAI)
|
8 |
+
# Create On : 2023-12-11 15:34
|
9 |
+
# Last Modified : 2023-12-20 03:59
|
10 |
+
# File Name : frontend.py
|
11 |
+
# Description :
|
12 |
+
#
|
13 |
+
# ===================================================
|
14 |
+
|
15 |
+
import argparse
|
16 |
+
import os
|
17 |
+
|
18 |
+
import gradio as gr
|
19 |
+
from demo.generation_frontend import build_generation
|
20 |
+
from demo.chat_frontend import build_chat
|
21 |
+
|
22 |
+
parser = argparse.ArgumentParser()
|
23 |
+
parser.add_argument("--title", type=str, default='Emu')
|
24 |
+
|
25 |
+
parser.add_argument("--host", type=str, default="0.0.0.0")
|
26 |
+
parser.add_argument("--port", type=int, default=9002)
|
27 |
+
parser.add_argument("--share", action="store_true")
|
28 |
+
parser.add_argument("--controller-url", type=str, default="http://218.91.113.230:9002")
|
29 |
+
parser.add_argument("--concurrency-count", type=int, default=8)
|
30 |
+
parser.add_argument("--disable-chat", action="store_true")
|
31 |
+
parser.add_argument("--disable-generate", action="store_true")
|
32 |
+
|
33 |
+
args = parser.parse_args()
|
34 |
+
|
35 |
+
|
36 |
+
if __name__ == "__main__":
|
37 |
+
title = "EmuV2: An Open Multimodal Generalist"
|
38 |
+
os.makedirs("log", exist_ok=True)
|
39 |
+
|
40 |
+
interface_list, tab_names = [], []
|
41 |
+
if not args.disable_generate:
|
42 |
+
demo_generation = build_generation(args)
|
43 |
+
interface_list.append(demo_generation)
|
44 |
+
tab_names.append("Multi-modal Generation")
|
45 |
+
|
46 |
+
if not args.disable_chat:
|
47 |
+
demo_chat = build_chat(args)
|
48 |
+
interface_list.append(demo_chat)
|
49 |
+
tab_names.append("Multi-modal Chat")
|
50 |
+
|
51 |
+
demo_all = gr.TabbedInterface(
|
52 |
+
interface_list=interface_list,
|
53 |
+
tab_names=tab_names,
|
54 |
+
title=title,
|
55 |
+
theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue"),
|
56 |
+
)
|
57 |
+
|
58 |
+
demo_all.queue(
|
59 |
+
concurrency_count=args.concurrency_count,
|
60 |
+
status_update_rate=3,
|
61 |
+
api_open=False,
|
62 |
+
).launch(
|
63 |
+
enable_queue=True,
|
64 |
+
server_name=args.host, server_port=args.port,
|
65 |
+
share=args.share,
|
66 |
+
)
|
demo/__init__.py
ADDED
File without changes
|
demo/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (163 Bytes). View file
|
|
demo/__pycache__/chat_frontend.cpython-310.pyc
ADDED
Binary file (4.93 kB). View file
|
|
demo/__pycache__/generation_frontend.cpython-310.pyc
ADDED
Binary file (5.03 kB). View file
|
|
demo/__pycache__/meta.cpython-310.pyc
ADDED
Binary file (8.26 kB). View file
|
|
demo/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (2.06 kB). View file
|
|
demo/chat_frontend.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
# ===================================================
|
4 |
+
#
|
5 |
+
# Author : Fan Zhang
|
6 |
+
# Email : zhangfan@baai.ac.cn
|
7 |
+
# Institute : Beijing Academy of Artificial Intelligence (BAAI)
|
8 |
+
# Create On : 2023-12-12 18:05
|
9 |
+
# Last Modified : 2023-12-19 15:00
|
10 |
+
# File Name : chat_frontend.py
|
11 |
+
# Description :
|
12 |
+
#
|
13 |
+
# ===================================================
|
14 |
+
|
15 |
+
import json
|
16 |
+
import io
|
17 |
+
import time
|
18 |
+
from PIL import Image
|
19 |
+
import requests
|
20 |
+
|
21 |
+
import gradio as gr
|
22 |
+
|
23 |
+
from .meta import ConvMeta, Role, DataMeta
|
24 |
+
from .utils import extract_frames
|
25 |
+
from .utils import frontend_logger as logging
|
26 |
+
|
27 |
+
CONTROLLER_URL = ""
|
28 |
+
|
29 |
+
def submit(
|
30 |
+
meta,
|
31 |
+
image,
|
32 |
+
video,
|
33 |
+
text,
|
34 |
+
num_frames,
|
35 |
+
):
|
36 |
+
if meta is None:
|
37 |
+
meta = ConvMeta()
|
38 |
+
|
39 |
+
meta.pop_error()
|
40 |
+
|
41 |
+
check_text = (text != "" and text is not None)
|
42 |
+
check_image = image is not None
|
43 |
+
check_video = video is not None
|
44 |
+
|
45 |
+
if check_text + check_image + check_video != 1:
|
46 |
+
logging.info(f"{meta.log_id}: invalid input: give multi madality simultaneously for single modality input")
|
47 |
+
gr.Error("Invalid input number, must give exactly one modality input at a time")
|
48 |
+
return meta.format_chatbot(), meta, None, None, ""
|
49 |
+
|
50 |
+
if check_text:
|
51 |
+
meta.append(Role.USER, DataMeta.build(text=text))
|
52 |
+
elif check_image:
|
53 |
+
meta.append(Role.USER, DataMeta.build(image=image))
|
54 |
+
elif check_video:
|
55 |
+
frames = extract_frames(video, num_frames)
|
56 |
+
meta.append(Role.USER, DataMeta.build(frames=frames))
|
57 |
+
|
58 |
+
return meta.format_chatbot(), meta, None, None, ""
|
59 |
+
|
60 |
+
|
61 |
+
def clear_history(meta):
|
62 |
+
if meta is None:
|
63 |
+
meta = ConvMeta()
|
64 |
+
meta.clear()
|
65 |
+
return meta.format_chatbot(), meta
|
66 |
+
|
67 |
+
|
68 |
+
def generate(
|
69 |
+
meta,
|
70 |
+
do_sample,
|
71 |
+
max_new_tokens,
|
72 |
+
temperature,
|
73 |
+
top_k,
|
74 |
+
top_p,
|
75 |
+
length_penalty,
|
76 |
+
num_beams,
|
77 |
+
repetition_penalty,
|
78 |
+
):
|
79 |
+
if meta is None:
|
80 |
+
meta = ConvMeta()
|
81 |
+
|
82 |
+
meta.pop_error()
|
83 |
+
meta.pop()
|
84 |
+
prompt = meta.format_chat()
|
85 |
+
|
86 |
+
prompt_list, image_list = [], {}
|
87 |
+
for idx, p in enumerate(prompt):
|
88 |
+
if isinstance(p, Image.Image):
|
89 |
+
key = f"[<IMAGE{idx}>]"
|
90 |
+
prompt_list.append(["IMAGE", key])
|
91 |
+
|
92 |
+
buf = io.BytesIO()
|
93 |
+
p.save(buf, format="PNG")
|
94 |
+
image_list[key] = (key, io.BytesIO(buf.getvalue()), "image/png")
|
95 |
+
else:
|
96 |
+
prompt_list.append(["TEXT", p])
|
97 |
+
|
98 |
+
if len(image_list) == 0:
|
99 |
+
image_list = None
|
100 |
+
|
101 |
+
logging.info(f"{meta.log_id}: construct chat reqeust with prompt {prompt_list}")
|
102 |
+
|
103 |
+
t0 = time.time()
|
104 |
+
try:
|
105 |
+
print(do_sample)
|
106 |
+
rsp = requests.post(
|
107 |
+
CONTROLLER_URL + "/v1/mmc",
|
108 |
+
files=image_list,
|
109 |
+
data={
|
110 |
+
"log_id": meta.log_id,
|
111 |
+
"prompt": json.dumps(prompt_list),
|
112 |
+
"do_sample": do_sample,
|
113 |
+
"max_new_tokens": max_new_tokens,
|
114 |
+
"temperature": temperature,
|
115 |
+
"top_k": top_k,
|
116 |
+
"top_p": top_p,
|
117 |
+
"length_penalty": length_penalty,
|
118 |
+
"num_beams": num_beams,
|
119 |
+
"repetition_penalty": repetition_penalty,
|
120 |
+
},
|
121 |
+
)
|
122 |
+
except:
|
123 |
+
rsp = requests.Response()
|
124 |
+
rsp.status_code = 1099
|
125 |
+
t1 = time.time()
|
126 |
+
|
127 |
+
logging.info(f"{meta.log_id}: get response with status code: {rsp.status_code}, time: {(t1-t0)*1000:.3f}ms")
|
128 |
+
|
129 |
+
if rsp.ok:
|
130 |
+
content = json.loads(rsp.text)
|
131 |
+
if content["code"] == 0:
|
132 |
+
meta.append(Role.ASSISTANT, DataMeta.build(text=content["data"]))
|
133 |
+
else:
|
134 |
+
meta.append(Role.ASSISTANT, DataMeta.build(text=f"GENERATE FAILED: {content['data']}"), is_error=True)
|
135 |
+
else:
|
136 |
+
meta.append(Role.ASSISTANT, DataMeta.build(text=f"GENERATE FAILED: http failed with code {rsp.status_code}"), is_error=True)
|
137 |
+
|
138 |
+
return meta.format_chatbot(), meta
|
139 |
+
|
140 |
+
|
141 |
+
def build_chat(args):
|
142 |
+
global CONTROLLER_URL
|
143 |
+
CONTROLLER_URL = args.controller_url
|
144 |
+
|
145 |
+
with gr.Blocks(title="Emu", theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")) as demo:
|
146 |
+
state = gr.State()
|
147 |
+
|
148 |
+
with gr.Row():
|
149 |
+
with gr.Column(scale=2):
|
150 |
+
with gr.Row():
|
151 |
+
imagebox = gr.Image(type="pil")
|
152 |
+
with gr.Row():
|
153 |
+
videobox = gr.Video()
|
154 |
+
|
155 |
+
with gr.Accordion("Parameters", open=True, visible=True) as parameter_row:
|
156 |
+
do_sample = gr.Checkbox(value=False, label="Do Sample", interactive=True)
|
157 |
+
max_new_tokens = gr.Slider(minimum=0, maximum=2048, value=512, step=1, interactive=True, label="Max Output Tokens")
|
158 |
+
temperature = gr.Slider(minimum=0, maximum=1, value=0.7, step=0.05, interactive=True, label="Temperature")
|
159 |
+
top_k = gr.Slider(minimum=1, maximum=5, value=3, step=1, interactive=True, label="Top K")
|
160 |
+
top_p = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.05, interactive=True, label="Top P")
|
161 |
+
length_penalty = gr.Slider(minimum=0, maximum=5, value=3, step=0.1, interactive=True, label="Length Penalty")
|
162 |
+
num_beams = gr.Slider(minimum=1, maximum=10, value=5, step=1, interactive=True, label="Beam Size")
|
163 |
+
repetition_penalty = gr.Slider(minimum=1.0, maximum=10.0, value=1.0, step=0.5, interactive=True, label="Repetition Penalty")
|
164 |
+
num_frames = gr.Number(interactive=True, value=8, maximum=12, label="Num Video Frames")
|
165 |
+
|
166 |
+
with gr.Column(scale=6):
|
167 |
+
chatbot = gr.Chatbot(
|
168 |
+
elem_id="chatbot",
|
169 |
+
label="Emu Chatbot",
|
170 |
+
visible=True,
|
171 |
+
height=1070,
|
172 |
+
)
|
173 |
+
|
174 |
+
with gr.Row():
|
175 |
+
with gr.Column(scale=8):
|
176 |
+
textbox = gr.Textbox(
|
177 |
+
show_label=False,
|
178 |
+
placeholder="Enter text and add to prompt",
|
179 |
+
visible=True,
|
180 |
+
container=False,
|
181 |
+
)
|
182 |
+
|
183 |
+
with gr.Column(scale=1, min_width=60):
|
184 |
+
add_btn = gr.Button(value="Add")
|
185 |
+
|
186 |
+
with gr.Row(visible=True) as button_row:
|
187 |
+
# upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
|
188 |
+
# downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
|
189 |
+
# flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
|
190 |
+
# regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
|
191 |
+
clear_btn = gr.Button(value="🗑️ Clear History")
|
192 |
+
generate_btn = gr.Button(value="Generate")
|
193 |
+
|
194 |
+
|
195 |
+
clear_btn.click(clear_history, inputs=state, outputs=[chatbot, state])
|
196 |
+
textbox.submit(
|
197 |
+
submit,
|
198 |
+
inputs=[
|
199 |
+
state,
|
200 |
+
imagebox,
|
201 |
+
videobox,
|
202 |
+
textbox,
|
203 |
+
num_frames,
|
204 |
+
],
|
205 |
+
outputs=[
|
206 |
+
chatbot,
|
207 |
+
state,
|
208 |
+
imagebox,
|
209 |
+
videobox,
|
210 |
+
textbox,
|
211 |
+
],
|
212 |
+
)
|
213 |
+
add_btn.click(
|
214 |
+
submit,
|
215 |
+
inputs=[
|
216 |
+
state,
|
217 |
+
imagebox,
|
218 |
+
videobox,
|
219 |
+
textbox,
|
220 |
+
num_frames,
|
221 |
+
],
|
222 |
+
outputs=[
|
223 |
+
chatbot,
|
224 |
+
state,
|
225 |
+
imagebox,
|
226 |
+
videobox,
|
227 |
+
textbox,
|
228 |
+
],
|
229 |
+
)
|
230 |
+
generate_btn.click(
|
231 |
+
generate,
|
232 |
+
inputs=[
|
233 |
+
state,
|
234 |
+
do_sample,
|
235 |
+
max_new_tokens,
|
236 |
+
temperature,
|
237 |
+
top_k,
|
238 |
+
top_p,
|
239 |
+
length_penalty,
|
240 |
+
num_beams,
|
241 |
+
repetition_penalty,
|
242 |
+
],
|
243 |
+
outputs=[
|
244 |
+
chatbot,
|
245 |
+
state,
|
246 |
+
],
|
247 |
+
)
|
248 |
+
|
249 |
+
return demo
|
demo/generation_frontend.py
ADDED
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
# ===================================================
|
4 |
+
#
|
5 |
+
# Author : Fan Zhang
|
6 |
+
# Email : zhangfan@baai.ac.cn
|
7 |
+
# Institute : Beijing Academy of Artificial Intelligence (BAAI)
|
8 |
+
# Create On : 2023-12-11 15:35
|
9 |
+
# Last Modified : 2023-12-19 15:02
|
10 |
+
# File Name : generation_frontend.py
|
11 |
+
# Description :
|
12 |
+
#
|
13 |
+
# ===================================================
|
14 |
+
|
15 |
+
import base64
|
16 |
+
import json
|
17 |
+
import io
|
18 |
+
import time
|
19 |
+
from PIL import Image
|
20 |
+
import requests
|
21 |
+
|
22 |
+
import gradio as gr
|
23 |
+
|
24 |
+
from emu.constants import EVA_IMAGE_SIZE
|
25 |
+
from .meta import ConvMeta, Role, DataMeta
|
26 |
+
from .utils import frontend_logger as logging
|
27 |
+
|
28 |
+
CONTROLLER_URL = ""
|
29 |
+
|
30 |
+
def submit(
|
31 |
+
meta,
|
32 |
+
enable_grd,
|
33 |
+
left,
|
34 |
+
top,
|
35 |
+
right,
|
36 |
+
bottom,
|
37 |
+
image,
|
38 |
+
text,
|
39 |
+
):
|
40 |
+
if meta is None:
|
41 |
+
meta = ConvMeta()
|
42 |
+
|
43 |
+
meta.pop_error()
|
44 |
+
if meta.has_gen:
|
45 |
+
meta.clear()
|
46 |
+
|
47 |
+
if enable_grd:
|
48 |
+
if text == "" and image is None:
|
49 |
+
logging.info(f"{meta.log_id}: invalid input: no valid data for grounding input")
|
50 |
+
gr.Error("text or image must be given if enable grounding generation")
|
51 |
+
return meta.format_chatbot(), meta, False, 0, 0, EVA_IMAGE_SIZE, EVA_IMAGE_SIZE, None, ""
|
52 |
+
|
53 |
+
meta.append(Role.USER, DataMeta.build(text=text, image=image, coordinate=[left, top, right, bottom]))
|
54 |
+
elif image is not None and text != "":
|
55 |
+
logging.info(f"{meta.log_id}: invalid input: give text and image simultaneously for single modality input")
|
56 |
+
gr.Error("Do not submit text and image data at the same time!!!")
|
57 |
+
return meta.format_chatbot(), meta, False, 0, 0, EVA_IMAGE_SIZE, EVA_IMAGE_SIZE, None, ""
|
58 |
+
elif image is not None:
|
59 |
+
meta.append(Role.USER, DataMeta.build(image=image))
|
60 |
+
elif text != "":
|
61 |
+
meta.append(Role.USER, DataMeta.build(text=text))
|
62 |
+
return meta.format_chatbot(), meta, False, 0, 0, EVA_IMAGE_SIZE, EVA_IMAGE_SIZE, None, ""
|
63 |
+
|
64 |
+
|
65 |
+
def clear_history(meta):
|
66 |
+
if meta is None:
|
67 |
+
meta = ConvMeta()
|
68 |
+
meta.clear()
|
69 |
+
return meta.format_chatbot(), meta
|
70 |
+
|
71 |
+
|
72 |
+
def generate(meta, classifier_free_guidance, steps):
|
73 |
+
if meta is None:
|
74 |
+
meta = ConvMeta()
|
75 |
+
|
76 |
+
meta.pop_error()
|
77 |
+
meta.pop()
|
78 |
+
prompt = meta.format_prompt()
|
79 |
+
|
80 |
+
prompt_list, image_list = [], {}
|
81 |
+
for idx, p in enumerate(prompt):
|
82 |
+
if isinstance(p, Image.Image):
|
83 |
+
key = f"[<IMAGE{idx}>]"
|
84 |
+
prompt_list.append(["IMAGE", key])
|
85 |
+
|
86 |
+
buf = io.BytesIO()
|
87 |
+
p.save(buf, format="PNG")
|
88 |
+
image_list[key] = (key, io.BytesIO(buf.getvalue()), "image/png")
|
89 |
+
else:
|
90 |
+
prompt_list.append(["TEXT", p])
|
91 |
+
|
92 |
+
|
93 |
+
if len(image_list) == 0:
|
94 |
+
image_list = None
|
95 |
+
|
96 |
+
logging.info(f"{meta.log_id}: construct generation reqeust with prompt {prompt_list}")
|
97 |
+
|
98 |
+
t0 = time.time()
|
99 |
+
try:
|
100 |
+
rsp = requests.post(
|
101 |
+
CONTROLLER_URL + "/v1/mmg",
|
102 |
+
files=image_list,
|
103 |
+
data={
|
104 |
+
"log_id": meta.log_id,
|
105 |
+
"prompt": json.dumps(prompt_list),
|
106 |
+
"classifier_free_guidance": classifier_free_guidance,
|
107 |
+
"steps": steps,
|
108 |
+
},
|
109 |
+
)
|
110 |
+
except:
|
111 |
+
rsp = requests.Response()
|
112 |
+
rsp.status_code = 1099
|
113 |
+
t1 = time.time()
|
114 |
+
|
115 |
+
logging.info(f"{meta.log_id}: get response with status code: {rsp.status_code}, time: {(t1-t0)*1000:.3f}ms")
|
116 |
+
|
117 |
+
if rsp.ok:
|
118 |
+
content = json.loads(rsp.text)
|
119 |
+
if content["code"] == 0:
|
120 |
+
image = Image.open(io.BytesIO(base64.b64decode(content["data"])))
|
121 |
+
meta.append(Role.ASSISTANT, DataMeta.build(image=image, resize=False))
|
122 |
+
else:
|
123 |
+
meta.append(Role.ASSISTANT, DataMeta.build(text=f"GENERATE FAILED: {content['data']}"))
|
124 |
+
else:
|
125 |
+
meta.append(Role.ASSISTANT, DataMeta.build(text=f"GENERATE FAILED: http failed with code {rsp.status_code}"))
|
126 |
+
|
127 |
+
return meta.format_chatbot(), meta
|
128 |
+
|
129 |
+
|
130 |
+
def build_generation(args):
|
131 |
+
global CONTROLLER_URL
|
132 |
+
CONTROLLER_URL = args.controller_url
|
133 |
+
|
134 |
+
with gr.Blocks(title="Emu", theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")) as demo:
|
135 |
+
state = gr.State()
|
136 |
+
|
137 |
+
with gr.Row():
|
138 |
+
with gr.Column(scale=2):
|
139 |
+
with gr.Row():
|
140 |
+
imagebox = gr.Image(type="pil")
|
141 |
+
|
142 |
+
with gr.Row():
|
143 |
+
with gr.Accordion("Grounding Parameters", open=True, visible=True) as grounding_row:
|
144 |
+
enable_grd = gr.Checkbox(label="Enable")
|
145 |
+
left = gr.Slider(minimum=0, maximum=EVA_IMAGE_SIZE, value=0, step=1, interactive=True, label="left")
|
146 |
+
top = gr.Slider(minimum=0, maximum=EVA_IMAGE_SIZE, value=0, step=1, interactive=True, label="top")
|
147 |
+
right = gr.Slider(minimum=0, maximum=EVA_IMAGE_SIZE, value=EVA_IMAGE_SIZE, step=1, interactive=True, label="right")
|
148 |
+
bottom = gr.Slider(minimum=0, maximum=EVA_IMAGE_SIZE, value=EVA_IMAGE_SIZE, step=1, interactive=True, label="bottom")
|
149 |
+
|
150 |
+
with gr.Row():
|
151 |
+
with gr.Accordion("Diffusion Parameters", open=True, visible=True) as parameters_row:
|
152 |
+
cfg = gr.Slider(minimum=1, maximum=30, value=3, step=0.5, interactive=True, label="classifier free guidance")
|
153 |
+
steps = gr.Slider(minimum=1, maximum=100, value=50, step=1, interactive=True, label="steps")
|
154 |
+
|
155 |
+
with gr.Column(scale=6):
|
156 |
+
chatbot = gr.Chatbot(
|
157 |
+
elem_id="chatbot",
|
158 |
+
label="Emu Chatbot",
|
159 |
+
visible=True,
|
160 |
+
height=720,
|
161 |
+
)
|
162 |
+
|
163 |
+
with gr.Row():
|
164 |
+
with gr.Column(scale=8):
|
165 |
+
textbox = gr.Textbox(
|
166 |
+
show_label=False,
|
167 |
+
placeholder="Enter text and add to prompt",
|
168 |
+
visible=True,
|
169 |
+
container=False,
|
170 |
+
)
|
171 |
+
|
172 |
+
with gr.Column(scale=1, min_width=60):
|
173 |
+
add_btn = gr.Button(value="Add")
|
174 |
+
|
175 |
+
with gr.Row(visible=True) as button_row:
|
176 |
+
# upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
|
177 |
+
# downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
|
178 |
+
# regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
|
179 |
+
clear_btn = gr.Button(value="🗑️ Clear History")
|
180 |
+
generate_btn = gr.Button(value="Generate")
|
181 |
+
|
182 |
+
clear_btn.click(clear_history, inputs=state, outputs=[chatbot, state])
|
183 |
+
|
184 |
+
textbox.submit(
|
185 |
+
submit,
|
186 |
+
inputs=[
|
187 |
+
state,
|
188 |
+
enable_grd,
|
189 |
+
left,
|
190 |
+
top,
|
191 |
+
right,
|
192 |
+
bottom,
|
193 |
+
imagebox,
|
194 |
+
textbox,
|
195 |
+
],
|
196 |
+
outputs=[
|
197 |
+
chatbot,
|
198 |
+
state,
|
199 |
+
enable_grd,
|
200 |
+
left,
|
201 |
+
top,
|
202 |
+
right,
|
203 |
+
bottom,
|
204 |
+
imagebox,
|
205 |
+
textbox,
|
206 |
+
],
|
207 |
+
)
|
208 |
+
|
209 |
+
add_btn.click(
|
210 |
+
submit,
|
211 |
+
inputs=[
|
212 |
+
state,
|
213 |
+
enable_grd,
|
214 |
+
left,
|
215 |
+
top,
|
216 |
+
right,
|
217 |
+
bottom,
|
218 |
+
imagebox,
|
219 |
+
textbox,
|
220 |
+
],
|
221 |
+
outputs=[
|
222 |
+
chatbot,
|
223 |
+
state,
|
224 |
+
enable_grd,
|
225 |
+
left,
|
226 |
+
top,
|
227 |
+
right,
|
228 |
+
bottom,
|
229 |
+
imagebox,
|
230 |
+
textbox,
|
231 |
+
],
|
232 |
+
)
|
233 |
+
|
234 |
+
generate_btn.click(
|
235 |
+
generate,
|
236 |
+
inputs=[
|
237 |
+
state,
|
238 |
+
cfg,
|
239 |
+
steps,
|
240 |
+
],
|
241 |
+
outputs=[
|
242 |
+
chatbot,
|
243 |
+
state,
|
244 |
+
]
|
245 |
+
)
|
246 |
+
|
247 |
+
return demo
|
demo/meta.py
ADDED
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
# ===========================================================================================
|
4 |
+
#
|
5 |
+
# Copyright (c) Beijing Academy of Artificial Intelligence (BAAI). All rights reserved.
|
6 |
+
#
|
7 |
+
# Author : Fan Zhang
|
8 |
+
# Email : zhangfan@baai.ac.cn
|
9 |
+
# Institute : Beijing Academy of Artificial Intelligence (BAAI)
|
10 |
+
# Create On : 2023-12-12 02:54
|
11 |
+
# Last Modified : 2023-12-19 15:00
|
12 |
+
# File Name : meta.py
|
13 |
+
# Description :
|
14 |
+
#
|
15 |
+
# ===========================================================================================
|
16 |
+
|
17 |
+
import base64
|
18 |
+
from dataclasses import dataclass, field
|
19 |
+
import io
|
20 |
+
from enum import Enum
|
21 |
+
from PIL import Image
|
22 |
+
from typing import List, Tuple
|
23 |
+
|
24 |
+
import cv2
|
25 |
+
import numpy as np
|
26 |
+
|
27 |
+
from emu.constants import EVA_IMAGE_SIZE, GRD_SYMBOL, BOP_SYMBOL, EOP_SYMBOL, BOO_SYMBOL, EOO_SYMBOL
|
28 |
+
from emu.constants import DEFAULT_VIDEO_TOKEN, DEFAULT_EOS_TOKEN, USER_TOKEN, ASSISTANT_TOKEN, FAKE_VIDEO_END_TOKEN
|
29 |
+
|
30 |
+
from .utils import gen_id, frontend_logger as logging
|
31 |
+
|
32 |
+
|
33 |
+
class Role(Enum):
|
34 |
+
UNKNOWN = 0,
|
35 |
+
USER = 1,
|
36 |
+
ASSISTANT = 2,
|
37 |
+
|
38 |
+
|
39 |
+
class DataType(Enum):
|
40 |
+
UNKNOWN = 0,
|
41 |
+
TEXT = 1,
|
42 |
+
IMAGE = 2,
|
43 |
+
GROUNDING = 3,
|
44 |
+
VIDEO = 4,
|
45 |
+
ERROR = 5,
|
46 |
+
|
47 |
+
|
48 |
+
@dataclass
|
49 |
+
class DataMeta:
|
50 |
+
datatype: DataType = DataType.UNKNOWN
|
51 |
+
text: str = None
|
52 |
+
image: Image.Image = None
|
53 |
+
mask: Image.Image = None
|
54 |
+
coordinate: List[int] = None
|
55 |
+
frames: List[Image.Image] = None
|
56 |
+
stack_frame: Image.Image = None
|
57 |
+
|
58 |
+
@property
|
59 |
+
def grounding(self):
|
60 |
+
return self.coordinate is not None
|
61 |
+
|
62 |
+
@property
|
63 |
+
def text_str(self):
|
64 |
+
return self.text
|
65 |
+
|
66 |
+
@property
|
67 |
+
def image_str(self):
|
68 |
+
return self.image2str(self.image)
|
69 |
+
|
70 |
+
@property
|
71 |
+
def video_str(self):
|
72 |
+
ret = f'<div style="overflow:scroll"><b>[VIDEO]</b></div>{self.image2str(self.stack_frame)}'
|
73 |
+
return ret
|
74 |
+
|
75 |
+
@property
|
76 |
+
def grounding_str(self):
|
77 |
+
ret = ""
|
78 |
+
if self.text is not None:
|
79 |
+
ret += f'<div style="overflow:scroll"><b>[PHRASE]</b>{self.text}</div>'
|
80 |
+
|
81 |
+
ret += self.image2str(self.mask)
|
82 |
+
|
83 |
+
if self.image is not None:
|
84 |
+
ret += self.image2str(self.image)
|
85 |
+
return ret
|
86 |
+
|
87 |
+
def image2str(self, image):
|
88 |
+
buf = io.BytesIO()
|
89 |
+
image.save(buf, format="WEBP")
|
90 |
+
i_str = base64.b64encode(buf.getvalue()).decode()
|
91 |
+
return f'<div style="float:left"><img src="data:image/png;base64, {i_str}"></div>'
|
92 |
+
|
93 |
+
def format_chatbot(self):
|
94 |
+
match self.datatype:
|
95 |
+
case DataType.TEXT:
|
96 |
+
return self.text_str
|
97 |
+
case DataType.IMAGE:
|
98 |
+
return self.image_str
|
99 |
+
case DataType.VIDEO:
|
100 |
+
return self.video_str
|
101 |
+
case DataType.GROUNDING:
|
102 |
+
return self.grounding_str
|
103 |
+
case _:
|
104 |
+
return ""
|
105 |
+
|
106 |
+
def format_prompt(self) -> List[str | Image.Image]:
|
107 |
+
match self.datatype:
|
108 |
+
case DataType.TEXT:
|
109 |
+
return [self.text]
|
110 |
+
case DataType.IMAGE:
|
111 |
+
return [self.image]
|
112 |
+
case DataType.VIDEO:
|
113 |
+
return [DEFAULT_VIDEO_TOKEN] + self.frames + [FAKE_VIDEO_END_TOKEN]
|
114 |
+
case DataType.GROUNDING:
|
115 |
+
ret = []
|
116 |
+
if self.text is not None:
|
117 |
+
ret.append(f"{BOP_SYMBOL}{self.text}{EOP_SYMBOL}")
|
118 |
+
ret += [BOO_SYMBOL, self.mask, EOO_SYMBOL]
|
119 |
+
if self.image is not None:
|
120 |
+
ret.append(self.image)
|
121 |
+
return ret
|
122 |
+
case _:
|
123 |
+
return []
|
124 |
+
|
125 |
+
def __str__(self):
|
126 |
+
s = ""
|
127 |
+
if self.text is not None:
|
128 |
+
s += f"T:{self.text}"
|
129 |
+
|
130 |
+
if self.image is not None:
|
131 |
+
w, h = self.image.size
|
132 |
+
s += f"[I:{h}x{w}]"
|
133 |
+
|
134 |
+
if self.coordinate is not None:
|
135 |
+
l, t, r, b = self.coordinate
|
136 |
+
s += f"[C:({l:03d},{t:03d}),({r:03d},{b:03d})]"
|
137 |
+
|
138 |
+
if self.frames is not None:
|
139 |
+
w, h = self.frames[0].size
|
140 |
+
s += f"[V:{len(self.frames)}x{h}x{w}]"
|
141 |
+
|
142 |
+
return s
|
143 |
+
|
144 |
+
@classmethod
|
145 |
+
def build(cls, text=None, image=None, coordinate=None, frames=None, is_error=False, *, resize: bool = True):
|
146 |
+
ins = cls()
|
147 |
+
ins.text = text if text != "" else None
|
148 |
+
ins.image = cls.resize(image, force=resize)
|
149 |
+
# ins.image = image
|
150 |
+
ins.coordinate = cls.fix(coordinate)
|
151 |
+
ins.frames = cls.resize(frames, force=resize)
|
152 |
+
# ins.frames = frames
|
153 |
+
|
154 |
+
if is_error:
|
155 |
+
ins.datatype = DataType.ERROR
|
156 |
+
elif coordinate is not None:
|
157 |
+
ins.datatype = DataType.GROUNDING
|
158 |
+
ins.draw_box()
|
159 |
+
elif image is not None:
|
160 |
+
ins.datatype = DataType.IMAGE
|
161 |
+
elif text is not None:
|
162 |
+
ins.datatype = DataType.TEXT
|
163 |
+
else:
|
164 |
+
ins.datatype = DataType.VIDEO
|
165 |
+
ins.stack()
|
166 |
+
|
167 |
+
return ins
|
168 |
+
|
169 |
+
@classmethod
|
170 |
+
def fix(cls, coordinate):
|
171 |
+
if coordinate is None:
|
172 |
+
return None
|
173 |
+
|
174 |
+
l, t, r, b = coordinate
|
175 |
+
l = min(EVA_IMAGE_SIZE, max(0, l))
|
176 |
+
t = min(EVA_IMAGE_SIZE, max(0, t))
|
177 |
+
r = min(EVA_IMAGE_SIZE, max(0, r))
|
178 |
+
b = min(EVA_IMAGE_SIZE, max(0, b))
|
179 |
+
return min(l, r), min(t, b), max(l, r), max(t, b)
|
180 |
+
|
181 |
+
@classmethod
|
182 |
+
def resize(cls, image: Image.Image | List[Image.Image] | None, *, force: bool = True):
|
183 |
+
if image is None:
|
184 |
+
return None
|
185 |
+
|
186 |
+
if not force:
|
187 |
+
return image
|
188 |
+
|
189 |
+
if isinstance(image, Image.Image):
|
190 |
+
image = [image]
|
191 |
+
|
192 |
+
for idx, im in enumerate(image):
|
193 |
+
w, h = im.size
|
194 |
+
if w < h:
|
195 |
+
h = int(EVA_IMAGE_SIZE / w * h)
|
196 |
+
w = EVA_IMAGE_SIZE
|
197 |
+
else:
|
198 |
+
w = int(EVA_IMAGE_SIZE / h * w)
|
199 |
+
h = EVA_IMAGE_SIZE
|
200 |
+
|
201 |
+
image[idx] = im.resize((w, h))
|
202 |
+
|
203 |
+
return image if len(image) > 1 else image[0]
|
204 |
+
|
205 |
+
def draw_box(self):
|
206 |
+
left, top, right, bottom = self.coordinate
|
207 |
+
mask = np.zeros((EVA_IMAGE_SIZE, EVA_IMAGE_SIZE, 3), dtype=np.uint8)
|
208 |
+
mask = cv2.rectangle(mask, (left, top), (right, bottom), (255, 255, 255), 3)
|
209 |
+
self.mask = Image.fromarray(mask)
|
210 |
+
|
211 |
+
def stack(self):
|
212 |
+
w, h = self.frames[0].size
|
213 |
+
n = len(self.frames)
|
214 |
+
stack_frame = Image.new(mode="RGB", size=(w*n, h))
|
215 |
+
for idx, f in enumerate(self.frames):
|
216 |
+
stack_frame.paste(f, (idx*w, 0))
|
217 |
+
self.stack_frame = stack_frame
|
218 |
+
|
219 |
+
|
220 |
+
class ConvMeta:
|
221 |
+
|
222 |
+
def __init__(self):
|
223 |
+
self.system: str = "You are a helpful assistant, dedicated to delivering comprehensive and meticulous responses."
|
224 |
+
self.message: List[Tuple[Role, DataMeta]] = []
|
225 |
+
self.log_id: str = gen_id()
|
226 |
+
|
227 |
+
logging.info(f"{self.log_id}: create new round of chat")
|
228 |
+
|
229 |
+
def append(self, r: Role, p: DataMeta):
|
230 |
+
logging.info(f"{self.log_id}: APPEND [{r.name}] prompt element, type: {p.datatype.name}, message: {p}")
|
231 |
+
self.message.append((r, p))
|
232 |
+
|
233 |
+
def format_chatbot(self):
|
234 |
+
ret = []
|
235 |
+
for r, p in self.message:
|
236 |
+
cur_p = p.format_chatbot()
|
237 |
+
if r == Role.USER:
|
238 |
+
ret.append((cur_p, None))
|
239 |
+
else:
|
240 |
+
ret.append((None, cur_p))
|
241 |
+
return ret
|
242 |
+
|
243 |
+
def format_prompt(self):
|
244 |
+
ret = []
|
245 |
+
has_coor = False
|
246 |
+
for _, p in self.message:
|
247 |
+
has_coor |= (p.datatype == DataType.GROUNDING)
|
248 |
+
ret += p.format_prompt()
|
249 |
+
|
250 |
+
if has_coor:
|
251 |
+
ret.insert(0, GRD_SYMBOL)
|
252 |
+
|
253 |
+
logging.info(f"{self.log_id}: format generation prompt: {ret}")
|
254 |
+
return ret
|
255 |
+
|
256 |
+
def format_chat(self):
|
257 |
+
ret = [self.system]
|
258 |
+
|
259 |
+
prev_r = None
|
260 |
+
for r, p in self.message:
|
261 |
+
if prev_r != r:
|
262 |
+
if prev_r == Role.ASSISTANT:
|
263 |
+
ret.append(f"{DEFAULT_EOS_TOKEN}{USER_TOKEN}: ")
|
264 |
+
elif prev_r is None:
|
265 |
+
ret.append(f" {USER_TOKEN}: ")
|
266 |
+
else:
|
267 |
+
ret.append(f" {ASSISTANT_TOKEN}: ")
|
268 |
+
ret += p.format_prompt()
|
269 |
+
prev_r = r
|
270 |
+
else:
|
271 |
+
ret += p.format_prompt()
|
272 |
+
|
273 |
+
ret.append(f" {ASSISTANT_TOKEN}:")
|
274 |
+
|
275 |
+
logging.info(f"{self.log_id}: format chat prompt: {ret}")
|
276 |
+
return ret
|
277 |
+
|
278 |
+
def clear(self):
|
279 |
+
logging.info(f"{self.log_id}: clear chat history, end current chat round.")
|
280 |
+
del self.message
|
281 |
+
self.message = []
|
282 |
+
self.log_id = gen_id()
|
283 |
+
|
284 |
+
def pop(self):
|
285 |
+
if self.has_gen:
|
286 |
+
logging.info(f"{self.log_id}: pop out previous generation / chat result")
|
287 |
+
self.message.pop()
|
288 |
+
|
289 |
+
def pop_error(self):
|
290 |
+
self.message = [(r, p) for r, p in self.message if p.datatype != DataType.ERROR]
|
291 |
+
|
292 |
+
@property
|
293 |
+
def has_gen(self):
|
294 |
+
if len(self.message) == 0:
|
295 |
+
return False
|
296 |
+
if self.message[-1][0] == Role.USER:
|
297 |
+
return False
|
298 |
+
return True
|
demo/utils.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
# ===================================================
|
4 |
+
#
|
5 |
+
# Author : Fan Zhang
|
6 |
+
# Email : zhangfan@baai.ac.cn
|
7 |
+
# Institute : Beijing Academy of Artificial Intelligence (BAAI)
|
8 |
+
# Create On : 2023-12-13 09:48
|
9 |
+
# Last Modified : 2023-12-14 01:53
|
10 |
+
# File Name : utils.py
|
11 |
+
# Description :
|
12 |
+
#
|
13 |
+
# ===================================================
|
14 |
+
|
15 |
+
from datetime import datetime
|
16 |
+
import logging
|
17 |
+
import logging.config
|
18 |
+
import hashlib
|
19 |
+
import os.path as osp
|
20 |
+
import uuid
|
21 |
+
from PIL import Image
|
22 |
+
|
23 |
+
from decord import VideoReader
|
24 |
+
|
25 |
+
|
26 |
+
def config_logger(logger_name):
|
27 |
+
logger_config = {
|
28 |
+
"version": 1,
|
29 |
+
"formatters": {
|
30 |
+
"standard": {
|
31 |
+
"format": "%(asctime)s - %(filename)s: %(lineno)d - [%(levelname)s] - %(message)s",
|
32 |
+
"datefmt": "%Y-%m-%d %H:%M:%S",
|
33 |
+
},
|
34 |
+
},
|
35 |
+
"handlers": {
|
36 |
+
"console": {
|
37 |
+
"class": "logging.StreamHandler",
|
38 |
+
"formatter": "standard",
|
39 |
+
"level": "INFO",
|
40 |
+
},
|
41 |
+
"file": {
|
42 |
+
"class": "logging.handlers.TimedRotatingFileHandler",
|
43 |
+
"filename": osp.join(osp.dirname(__file__), "..", "log", f"{logger_name}.log"),
|
44 |
+
"formatter": "standard",
|
45 |
+
"level": "INFO",
|
46 |
+
"when": "D",
|
47 |
+
"interval": 7,
|
48 |
+
"backupCount": 90,
|
49 |
+
},
|
50 |
+
},
|
51 |
+
"loggers": {
|
52 |
+
logger_name: {
|
53 |
+
"handlers": ["file", "console"],
|
54 |
+
"level": "INFO",
|
55 |
+
"propagate": True,
|
56 |
+
},
|
57 |
+
},
|
58 |
+
}
|
59 |
+
|
60 |
+
logging.config.dictConfig(logger_config)
|
61 |
+
logger = logging.getLogger(logger_name)
|
62 |
+
return logger
|
63 |
+
|
64 |
+
frontend_logger = config_logger("Emu-v2_frontend")
|
65 |
+
beckend_logger = config_logger("Emu-v2_backend")
|
66 |
+
|
67 |
+
|
68 |
+
def extract_frames(video, num_frames):
|
69 |
+
video = VideoReader(video)
|
70 |
+
total_frames = len(video)
|
71 |
+
segment = int(total_frames // num_frames)
|
72 |
+
|
73 |
+
frames = video.get_batch(list(range(int(segment//2), total_frames, segment))).asnumpy()
|
74 |
+
frames = [Image.fromarray(f) for f in frames]
|
75 |
+
return frames
|
76 |
+
|
77 |
+
|
78 |
+
def image2md5(image: Image.Image):
|
79 |
+
md5hash = hashlib.md5(image.tobytes())
|
80 |
+
return md5hash.hexdigest()
|
81 |
+
|
82 |
+
|
83 |
+
def gen_id():
|
84 |
+
logid = datetime.now().strftime("%Y%m%d%H%M%d")
|
85 |
+
logid += f"{uuid.uuid4().hex}"
|
86 |
+
return logid
|
87 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
pillow
|
3 |
+
numpy
|
4 |
+
opencv-python
|
5 |
+
decord
|