ryanzhangfan commited on
Commit
9aa6aea
1 Parent(s): f8f41f8

initial commit

Browse files
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # ===================================================
4
+ #
5
+ # Author : Fan Zhang
6
+ # Email : zhangfan@baai.ac.cn
7
+ # Institute : Beijing Academy of Artificial Intelligence (BAAI)
8
+ # Create On : 2023-12-11 15:34
9
+ # Last Modified : 2023-12-20 03:59
10
+ # File Name : frontend.py
11
+ # Description :
12
+ #
13
+ # ===================================================
14
+
15
+ import argparse
16
+ import os
17
+
18
+ import gradio as gr
19
+ from demo.generation_frontend import build_generation
20
+ from demo.chat_frontend import build_chat
21
+
22
parser = argparse.ArgumentParser()
parser.add_argument("--title", type=str, default='Emu')

parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=9002)
parser.add_argument("--share", action="store_true")
parser.add_argument("--controller-url", type=str, default="http://218.91.113.230:9002")
parser.add_argument("--concurrency-count", type=int, default=8)
parser.add_argument("--disable-chat", action="store_true")
parser.add_argument("--disable-generate", action="store_true")

args = parser.parse_args()


if __name__ == "__main__":
    title = "EmuV2: An Open Multimodal Generalist"
    # The loggers in demo.utils write under ./log; make sure it exists first.
    os.makedirs("log", exist_ok=True)

    # Collect one (builder, tab name) pair per enabled demo, generation first.
    enabled = []
    if not args.disable_generate:
        enabled.append((build_generation, "Multi-modal Generation"))
    if not args.disable_chat:
        enabled.append((build_chat, "Multi-modal Chat"))

    interface_list = [builder(args) for builder, _ in enabled]
    tab_names = [name for _, name in enabled]

    demo_all = gr.TabbedInterface(
        interface_list=interface_list,
        tab_names=tab_names,
        title=title,
        theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue"),
    )

    # Queue requests so long-running generations don't block the UI.
    demo_all.queue(
        concurrency_count=args.concurrency_count,
        status_update_rate=3,
        api_open=False,
    ).launch(
        enable_queue=True,
        server_name=args.host,
        server_port=args.port,
        share=args.share,
    )
demo/__init__.py ADDED
File without changes
demo/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (163 Bytes). View file
 
demo/__pycache__/chat_frontend.cpython-310.pyc ADDED
Binary file (4.93 kB). View file
 
demo/__pycache__/generation_frontend.cpython-310.pyc ADDED
Binary file (5.03 kB). View file
 
demo/__pycache__/meta.cpython-310.pyc ADDED
Binary file (8.26 kB). View file
 
demo/__pycache__/utils.cpython-310.pyc ADDED
Binary file (2.06 kB). View file
 
demo/chat_frontend.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # ===================================================
4
+ #
5
+ # Author : Fan Zhang
6
+ # Email : zhangfan@baai.ac.cn
7
+ # Institute : Beijing Academy of Artificial Intelligence (BAAI)
8
+ # Create On : 2023-12-12 18:05
9
+ # Last Modified : 2023-12-19 15:00
10
+ # File Name : chat_frontend.py
11
+ # Description :
12
+ #
13
+ # ===================================================
14
+
15
+ import json
16
+ import io
17
+ import time
18
+ from PIL import Image
19
+ import requests
20
+
21
+ import gradio as gr
22
+
23
+ from .meta import ConvMeta, Role, DataMeta
24
+ from .utils import extract_frames
25
+ from .utils import frontend_logger as logging
26
+
27
+ CONTROLLER_URL = ""
28
+
29
def submit(
    meta,
    image,
    video,
    text,
    num_frames,
):
    """Append exactly one user input (text, image, or video) to the conversation.

    Returns ``(chatbot_history, state, None, None, "")`` — the trailing values
    reset the image, video, and text widgets after a successful add.
    """
    if meta is None:
        meta = ConvMeta()

    # Drop any error entries left over from a previous failed generation.
    meta.pop_error()

    check_text = (text != "" and text is not None)
    check_image = image is not None
    check_video = video is not None

    # Exactly one modality must be provided per submission.
    if check_text + check_image + check_video != 1:
        logging.info(f"{meta.log_id}: invalid input: give multi modality simultaneously for single modality input")
        # NOTE(review): gr.Error is constructed but not raised, so no popup is
        # shown; ``raise gr.Error(...)`` is the gradio idiom. Confirm whether
        # silently ignoring the input is intended before changing control flow.
        gr.Error("Invalid input number, must give exactly one modality input at a time")
        return meta.format_chatbot(), meta, None, None, ""

    if check_text:
        meta.append(Role.USER, DataMeta.build(text=text))
    elif check_image:
        meta.append(Role.USER, DataMeta.build(image=image))
    elif check_video:
        # Sample `num_frames` frames from the uploaded video file.
        frames = extract_frames(video, num_frames)
        meta.append(Role.USER, DataMeta.build(frames=frames))

    return meta.format_chatbot(), meta, None, None, ""
59
+
60
+
61
def clear_history(meta):
    """Reset the conversation and return an empty chatbot transcript."""
    if meta is None:
        meta = ConvMeta()
    meta.clear()
    chat_view = meta.format_chatbot()
    return chat_view, meta
66
+
67
+
68
def generate(
    meta,
    do_sample,
    max_new_tokens,
    temperature,
    top_k,
    top_p,
    length_penalty,
    num_beams,
    repetition_penalty,
):
    """POST the current conversation to the controller's chat endpoint (/v1/mmc)
    and append the assistant's reply (or an error entry) to the conversation.

    Returns ``(chatbot_history, state)``.
    """
    if meta is None:
        meta = ConvMeta()

    meta.pop_error()
    # Drop a previous assistant reply (if any) so the turn can be regenerated.
    meta.pop()
    prompt = meta.format_chat()

    # Serialize the mixed prompt: images are uploaded as PNG multipart files
    # referenced from the prompt list by placeholder keys.
    prompt_list, image_list = [], {}
    for idx, p in enumerate(prompt):
        if isinstance(p, Image.Image):
            key = f"[<IMAGE{idx}>]"
            prompt_list.append(["IMAGE", key])

            buf = io.BytesIO()
            p.save(buf, format="PNG")
            image_list[key] = (key, io.BytesIO(buf.getvalue()), "image/png")
        else:
            prompt_list.append(["TEXT", p])

    if len(image_list) == 0:
        image_list = None

    logging.info(f"{meta.log_id}: construct chat request with prompt {prompt_list}")

    t0 = time.time()
    try:
        rsp = requests.post(
            CONTROLLER_URL + "/v1/mmc",
            files=image_list,
            data={
                "log_id": meta.log_id,
                "prompt": json.dumps(prompt_list),
                "do_sample": do_sample,
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
                "top_k": top_k,
                "top_p": top_p,
                "length_penalty": length_penalty,
                "num_beams": num_beams,
                "repetition_penalty": repetition_penalty,
            },
        )
    except Exception:
        # Network/controller failure: log it and fabricate a response carrying
        # a sentinel status code so the error branch below handles it uniformly.
        logging.exception(f"{meta.log_id}: chat request failed")
        rsp = requests.Response()
        rsp.status_code = 1099
    t1 = time.time()

    logging.info(f"{meta.log_id}: get response with status code: {rsp.status_code}, time: {(t1-t0)*1000:.3f}ms")

    if rsp.ok:
        content = json.loads(rsp.text)
        if content["code"] == 0:
            meta.append(Role.ASSISTANT, DataMeta.build(text=content["data"]))
        else:
            # Backend reported an application-level error; mark it so
            # pop_error() can remove it on the next interaction.
            meta.append(Role.ASSISTANT, DataMeta.build(text=f"GENERATE FAILED: {content['data']}"), is_error=True)
    else:
        meta.append(Role.ASSISTANT, DataMeta.build(text=f"GENERATE FAILED: http failed with code {rsp.status_code}"), is_error=True)

    return meta.format_chatbot(), meta
139
+
140
+
141
def build_chat(args):
    """Build the multi-modal chat tab: media/parameter inputs on the left,
    chatbot transcript and text input on the right."""
    # Point the handlers at the backend controller chosen on the command line.
    global CONTROLLER_URL
    CONTROLLER_URL = args.controller_url

    with gr.Blocks(title="Emu", theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")) as demo:
        # Per-session conversation state; None until the first interaction,
        # handlers lazily replace it with a ConvMeta.
        state = gr.State()

        with gr.Row():
            with gr.Column(scale=2):
                # Left column: image/video upload plus decoding parameters.
                with gr.Row():
                    imagebox = gr.Image(type="pil")
                with gr.Row():
                    videobox = gr.Video()

                with gr.Accordion("Parameters", open=True, visible=True) as parameter_row:
                    do_sample = gr.Checkbox(value=False, label="Do Sample", interactive=True)
                    max_new_tokens = gr.Slider(minimum=0, maximum=2048, value=512, step=1, interactive=True, label="Max Output Tokens")
                    temperature = gr.Slider(minimum=0, maximum=1, value=0.7, step=0.05, interactive=True, label="Temperature")
                    top_k = gr.Slider(minimum=1, maximum=5, value=3, step=1, interactive=True, label="Top K")
                    top_p = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.05, interactive=True, label="Top P")
                    length_penalty = gr.Slider(minimum=0, maximum=5, value=3, step=0.1, interactive=True, label="Length Penalty")
                    num_beams = gr.Slider(minimum=1, maximum=10, value=5, step=1, interactive=True, label="Beam Size")
                    repetition_penalty = gr.Slider(minimum=1.0, maximum=10.0, value=1.0, step=0.5, interactive=True, label="Repetition Penalty")
                    num_frames = gr.Number(interactive=True, value=8, maximum=12, label="Num Video Frames")

            with gr.Column(scale=6):
                # Right column: transcript, text entry, and action buttons.
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="Emu Chatbot",
                    visible=True,
                    height=1070,
                )

                with gr.Row():
                    with gr.Column(scale=8):
                        textbox = gr.Textbox(
                            show_label=False,
                            placeholder="Enter text and add to prompt",
                            visible=True,
                            container=False,
                        )

                    with gr.Column(scale=1, min_width=60):
                        add_btn = gr.Button(value="Add")

                with gr.Row(visible=True) as button_row:
                    # upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    # downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    # flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                    # regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🗑️ Clear History")
                    generate_btn = gr.Button(value="Generate")


        clear_btn.click(clear_history, inputs=state, outputs=[chatbot, state])
        # Pressing Enter in the textbox and clicking Add share the same handler
        # and the same input/output wiring.
        textbox.submit(
            submit,
            inputs=[
                state,
                imagebox,
                videobox,
                textbox,
                num_frames,
            ],
            outputs=[
                chatbot,
                state,
                imagebox,
                videobox,
                textbox,
            ],
        )
        add_btn.click(
            submit,
            inputs=[
                state,
                imagebox,
                videobox,
                textbox,
                num_frames,
            ],
            outputs=[
                chatbot,
                state,
                imagebox,
                videobox,
                textbox,
            ],
        )
        generate_btn.click(
            generate,
            inputs=[
                state,
                do_sample,
                max_new_tokens,
                temperature,
                top_k,
                top_p,
                length_penalty,
                num_beams,
                repetition_penalty,
            ],
            outputs=[
                chatbot,
                state,
            ],
        )

    return demo
demo/generation_frontend.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # ===================================================
4
+ #
5
+ # Author : Fan Zhang
6
+ # Email : zhangfan@baai.ac.cn
7
+ # Institute : Beijing Academy of Artificial Intelligence (BAAI)
8
+ # Create On : 2023-12-11 15:35
9
+ # Last Modified : 2023-12-19 15:02
10
+ # File Name : generation_frontend.py
11
+ # Description :
12
+ #
13
+ # ===================================================
14
+
15
+ import base64
16
+ import json
17
+ import io
18
+ import time
19
+ from PIL import Image
20
+ import requests
21
+
22
+ import gradio as gr
23
+
24
+ from emu.constants import EVA_IMAGE_SIZE
25
+ from .meta import ConvMeta, Role, DataMeta
26
+ from .utils import frontend_logger as logging
27
+
28
+ CONTROLLER_URL = ""
29
+
30
def submit(
    meta,
    enable_grd,
    left,
    top,
    right,
    bottom,
    image,
    text,
):
    """Append one user input (text, image, or a grounded text+box prompt).

    Returns the chatbot history, the state, and reset values for the
    grounding controls, image box, and textbox.
    """
    if meta is None:
        meta = ConvMeta()

    meta.pop_error()
    # A previous generation result ends the round; start a fresh prompt.
    if meta.has_gen:
        meta.clear()

    if enable_grd:
        # ``not text`` also covers text=None (the original ``text == ""`` check
        # let None slip through).
        if not text and image is None:
            logging.info(f"{meta.log_id}: invalid input: no valid data for grounding input")
            # NOTE(review): gr.Error is constructed but not raised, so no popup
            # appears; confirm before switching to ``raise gr.Error(...)``.
            gr.Error("text or image must be given if enable grounding generation")
            return meta.format_chatbot(), meta, False, 0, 0, EVA_IMAGE_SIZE, EVA_IMAGE_SIZE, None, ""

        meta.append(Role.USER, DataMeta.build(text=text, image=image, coordinate=[left, top, right, bottom]))
    elif image is not None and text:
        logging.info(f"{meta.log_id}: invalid input: give text and image simultaneously for single modality input")
        gr.Error("Do not submit text and image data at the same time!!!")
        return meta.format_chatbot(), meta, False, 0, 0, EVA_IMAGE_SIZE, EVA_IMAGE_SIZE, None, ""
    elif image is not None:
        meta.append(Role.USER, DataMeta.build(image=image))
    elif text:
        # Truthiness check: the original ``text != ""`` passed for text=None and
        # built an empty DataMeta that fell into the VIDEO branch and crashed.
        meta.append(Role.USER, DataMeta.build(text=text))
    return meta.format_chatbot(), meta, False, 0, 0, EVA_IMAGE_SIZE, EVA_IMAGE_SIZE, None, ""
63
+
64
+
65
def clear_history(meta):
    """Wipe the current round and return an empty chatbot transcript."""
    if meta is None:
        meta = ConvMeta()
    meta.clear()
    chat_view = meta.format_chatbot()
    return chat_view, meta
70
+
71
+
72
def generate(meta, classifier_free_guidance, steps):
    """POST the current prompt to the controller's generation endpoint (/v1/mmg)
    and append the decoded result image (or an error entry) to the conversation.

    Returns ``(chatbot_history, state)``.
    """
    if meta is None:
        meta = ConvMeta()

    meta.pop_error()
    # Drop a previous generation result (if any) so the prompt can be re-run.
    meta.pop()
    prompt = meta.format_prompt()

    # Serialize the mixed prompt: images are uploaded as PNG multipart files
    # referenced from the prompt list by placeholder keys.
    prompt_list, image_list = [], {}
    for idx, p in enumerate(prompt):
        if isinstance(p, Image.Image):
            key = f"[<IMAGE{idx}>]"
            prompt_list.append(["IMAGE", key])

            buf = io.BytesIO()
            p.save(buf, format="PNG")
            image_list[key] = (key, io.BytesIO(buf.getvalue()), "image/png")
        else:
            prompt_list.append(["TEXT", p])


    if len(image_list) == 0:
        image_list = None

    logging.info(f"{meta.log_id}: construct generation request with prompt {prompt_list}")

    t0 = time.time()
    try:
        rsp = requests.post(
            CONTROLLER_URL + "/v1/mmg",
            files=image_list,
            data={
                "log_id": meta.log_id,
                "prompt": json.dumps(prompt_list),
                "classifier_free_guidance": classifier_free_guidance,
                "steps": steps,
            },
        )
    except Exception:
        # Network/controller failure: log it and fabricate a response carrying
        # a sentinel status code so the error branch below handles it uniformly.
        logging.exception(f"{meta.log_id}: generation request failed")
        rsp = requests.Response()
        rsp.status_code = 1099
    t1 = time.time()

    logging.info(f"{meta.log_id}: get response with status code: {rsp.status_code}, time: {(t1-t0)*1000:.3f}ms")

    if rsp.ok:
        content = json.loads(rsp.text)
        if content["code"] == 0:
            # The generated image arrives base64-encoded in the JSON payload.
            image = Image.open(io.BytesIO(base64.b64decode(content["data"])))
            meta.append(Role.ASSISTANT, DataMeta.build(image=image, resize=False))
        else:
            # Mark failures as errors so pop_error() removes them on the next
            # interaction (chat_frontend.py already does this; the original
            # here omitted is_error=True, leaving stale failures in history).
            meta.append(Role.ASSISTANT, DataMeta.build(text=f"GENERATE FAILED: {content['data']}"), is_error=True)
    else:
        meta.append(Role.ASSISTANT, DataMeta.build(text=f"GENERATE FAILED: http failed with code {rsp.status_code}"), is_error=True)

    return meta.format_chatbot(), meta
128
+
129
+
130
def build_generation(args):
    """Build the multi-modal generation tab: image/grounding/diffusion inputs
    on the left, chatbot transcript and text input on the right."""
    # Point the handlers at the backend controller chosen on the command line.
    global CONTROLLER_URL
    CONTROLLER_URL = args.controller_url

    with gr.Blocks(title="Emu", theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")) as demo:
        # Per-session conversation state; None until the first interaction,
        # handlers lazily replace it with a ConvMeta.
        state = gr.State()

        with gr.Row():
            with gr.Column(scale=2):
                with gr.Row():
                    imagebox = gr.Image(type="pil")

                with gr.Row():
                    # Bounding-box controls in EVA_IMAGE_SIZE pixel coordinates.
                    with gr.Accordion("Grounding Parameters", open=True, visible=True) as grounding_row:
                        enable_grd = gr.Checkbox(label="Enable")
                        left = gr.Slider(minimum=0, maximum=EVA_IMAGE_SIZE, value=0, step=1, interactive=True, label="left")
                        top = gr.Slider(minimum=0, maximum=EVA_IMAGE_SIZE, value=0, step=1, interactive=True, label="top")
                        right = gr.Slider(minimum=0, maximum=EVA_IMAGE_SIZE, value=EVA_IMAGE_SIZE, step=1, interactive=True, label="right")
                        bottom = gr.Slider(minimum=0, maximum=EVA_IMAGE_SIZE, value=EVA_IMAGE_SIZE, step=1, interactive=True, label="bottom")

                with gr.Row():
                    with gr.Accordion("Diffusion Parameters", open=True, visible=True) as parameters_row:
                        cfg = gr.Slider(minimum=1, maximum=30, value=3, step=0.5, interactive=True, label="classifier free guidance")
                        steps = gr.Slider(minimum=1, maximum=100, value=50, step=1, interactive=True, label="steps")

            with gr.Column(scale=6):
                # Right column: transcript, text entry, and action buttons.
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="Emu Chatbot",
                    visible=True,
                    height=720,
                )

                with gr.Row():
                    with gr.Column(scale=8):
                        textbox = gr.Textbox(
                            show_label=False,
                            placeholder="Enter text and add to prompt",
                            visible=True,
                            container=False,
                        )

                    with gr.Column(scale=1, min_width=60):
                        add_btn = gr.Button(value="Add")

                with gr.Row(visible=True) as button_row:
                    # upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    # downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    # regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🗑️ Clear History")
                    generate_btn = gr.Button(value="Generate")

        clear_btn.click(clear_history, inputs=state, outputs=[chatbot, state])

        # Pressing Enter in the textbox and clicking Add share the same handler
        # and the same input/output wiring.
        textbox.submit(
            submit,
            inputs=[
                state,
                enable_grd,
                left,
                top,
                right,
                bottom,
                imagebox,
                textbox,
            ],
            outputs=[
                chatbot,
                state,
                enable_grd,
                left,
                top,
                right,
                bottom,
                imagebox,
                textbox,
            ],
        )

        add_btn.click(
            submit,
            inputs=[
                state,
                enable_grd,
                left,
                top,
                right,
                bottom,
                imagebox,
                textbox,
            ],
            outputs=[
                chatbot,
                state,
                enable_grd,
                left,
                top,
                right,
                bottom,
                imagebox,
                textbox,
            ],
        )

        generate_btn.click(
            generate,
            inputs=[
                state,
                cfg,
                steps,
            ],
            outputs=[
                chatbot,
                state,
            ]
        )

    return demo
demo/meta.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # ===========================================================================================
4
+ #
5
+ # Copyright (c) Beijing Academy of Artificial Intelligence (BAAI). All rights reserved.
6
+ #
7
+ # Author : Fan Zhang
8
+ # Email : zhangfan@baai.ac.cn
9
+ # Institute : Beijing Academy of Artificial Intelligence (BAAI)
10
+ # Create On : 2023-12-12 02:54
11
+ # Last Modified : 2023-12-19 15:00
12
+ # File Name : meta.py
13
+ # Description :
14
+ #
15
+ # ===========================================================================================
16
+
17
+ import base64
18
+ from dataclasses import dataclass, field
19
+ import io
20
+ from enum import Enum
21
+ from PIL import Image
22
+ from typing import List, Tuple
23
+
24
+ import cv2
25
+ import numpy as np
26
+
27
+ from emu.constants import EVA_IMAGE_SIZE, GRD_SYMBOL, BOP_SYMBOL, EOP_SYMBOL, BOO_SYMBOL, EOO_SYMBOL
28
+ from emu.constants import DEFAULT_VIDEO_TOKEN, DEFAULT_EOS_TOKEN, USER_TOKEN, ASSISTANT_TOKEN, FAKE_VIDEO_END_TOKEN
29
+
30
+ from .utils import gen_id, frontend_logger as logging
31
+
32
+
33
class Role(Enum):
    """Speaker of a conversation turn."""
    # The original definition had trailing commas (``UNKNOWN = 0,``), which
    # made every member's value a 1-tuple like (0,); plain ints are intended.
    # Members are only compared by identity elsewhere, so this is safe.
    UNKNOWN = 0
    USER = 1
    ASSISTANT = 2
+
38
+
39
class DataType(Enum):
    """Kind of payload carried by a DataMeta prompt element."""
    # The original definition had trailing commas (``TEXT = 1,``), which made
    # every member's value a 1-tuple like (1,); plain ints are intended.
    # Members are only compared by identity elsewhere, so this is safe.
    UNKNOWN = 0
    TEXT = 1
    IMAGE = 2
    GROUNDING = 3
    VIDEO = 4
    ERROR = 5
+
47
+
48
+ @dataclass
49
+ class DataMeta:
50
+ datatype: DataType = DataType.UNKNOWN
51
+ text: str = None
52
+ image: Image.Image = None
53
+ mask: Image.Image = None
54
+ coordinate: List[int] = None
55
+ frames: List[Image.Image] = None
56
+ stack_frame: Image.Image = None
57
+
58
+ @property
59
+ def grounding(self):
60
+ return self.coordinate is not None
61
+
62
+ @property
63
+ def text_str(self):
64
+ return self.text
65
+
66
+ @property
67
+ def image_str(self):
68
+ return self.image2str(self.image)
69
+
70
+ @property
71
+ def video_str(self):
72
+ ret = f'<div style="overflow:scroll"><b>[VIDEO]</b></div>{self.image2str(self.stack_frame)}'
73
+ return ret
74
+
75
+ @property
76
+ def grounding_str(self):
77
+ ret = ""
78
+ if self.text is not None:
79
+ ret += f'<div style="overflow:scroll"><b>[PHRASE]</b>{self.text}</div>'
80
+
81
+ ret += self.image2str(self.mask)
82
+
83
+ if self.image is not None:
84
+ ret += self.image2str(self.image)
85
+ return ret
86
+
87
+ def image2str(self, image):
88
+ buf = io.BytesIO()
89
+ image.save(buf, format="WEBP")
90
+ i_str = base64.b64encode(buf.getvalue()).decode()
91
+ return f'<div style="float:left"><img src="data:image/png;base64, {i_str}"></div>'
92
+
93
+ def format_chatbot(self):
94
+ match self.datatype:
95
+ case DataType.TEXT:
96
+ return self.text_str
97
+ case DataType.IMAGE:
98
+ return self.image_str
99
+ case DataType.VIDEO:
100
+ return self.video_str
101
+ case DataType.GROUNDING:
102
+ return self.grounding_str
103
+ case _:
104
+ return ""
105
+
106
+ def format_prompt(self) -> List[str | Image.Image]:
107
+ match self.datatype:
108
+ case DataType.TEXT:
109
+ return [self.text]
110
+ case DataType.IMAGE:
111
+ return [self.image]
112
+ case DataType.VIDEO:
113
+ return [DEFAULT_VIDEO_TOKEN] + self.frames + [FAKE_VIDEO_END_TOKEN]
114
+ case DataType.GROUNDING:
115
+ ret = []
116
+ if self.text is not None:
117
+ ret.append(f"{BOP_SYMBOL}{self.text}{EOP_SYMBOL}")
118
+ ret += [BOO_SYMBOL, self.mask, EOO_SYMBOL]
119
+ if self.image is not None:
120
+ ret.append(self.image)
121
+ return ret
122
+ case _:
123
+ return []
124
+
125
+ def __str__(self):
126
+ s = ""
127
+ if self.text is not None:
128
+ s += f"T:{self.text}"
129
+
130
+ if self.image is not None:
131
+ w, h = self.image.size
132
+ s += f"[I:{h}x{w}]"
133
+
134
+ if self.coordinate is not None:
135
+ l, t, r, b = self.coordinate
136
+ s += f"[C:({l:03d},{t:03d}),({r:03d},{b:03d})]"
137
+
138
+ if self.frames is not None:
139
+ w, h = self.frames[0].size
140
+ s += f"[V:{len(self.frames)}x{h}x{w}]"
141
+
142
+ return s
143
+
144
+ @classmethod
145
+ def build(cls, text=None, image=None, coordinate=None, frames=None, is_error=False, *, resize: bool = True):
146
+ ins = cls()
147
+ ins.text = text if text != "" else None
148
+ ins.image = cls.resize(image, force=resize)
149
+ # ins.image = image
150
+ ins.coordinate = cls.fix(coordinate)
151
+ ins.frames = cls.resize(frames, force=resize)
152
+ # ins.frames = frames
153
+
154
+ if is_error:
155
+ ins.datatype = DataType.ERROR
156
+ elif coordinate is not None:
157
+ ins.datatype = DataType.GROUNDING
158
+ ins.draw_box()
159
+ elif image is not None:
160
+ ins.datatype = DataType.IMAGE
161
+ elif text is not None:
162
+ ins.datatype = DataType.TEXT
163
+ else:
164
+ ins.datatype = DataType.VIDEO
165
+ ins.stack()
166
+
167
+ return ins
168
+
169
+ @classmethod
170
+ def fix(cls, coordinate):
171
+ if coordinate is None:
172
+ return None
173
+
174
+ l, t, r, b = coordinate
175
+ l = min(EVA_IMAGE_SIZE, max(0, l))
176
+ t = min(EVA_IMAGE_SIZE, max(0, t))
177
+ r = min(EVA_IMAGE_SIZE, max(0, r))
178
+ b = min(EVA_IMAGE_SIZE, max(0, b))
179
+ return min(l, r), min(t, b), max(l, r), max(t, b)
180
+
181
+ @classmethod
182
+ def resize(cls, image: Image.Image | List[Image.Image] | None, *, force: bool = True):
183
+ if image is None:
184
+ return None
185
+
186
+ if not force:
187
+ return image
188
+
189
+ if isinstance(image, Image.Image):
190
+ image = [image]
191
+
192
+ for idx, im in enumerate(image):
193
+ w, h = im.size
194
+ if w < h:
195
+ h = int(EVA_IMAGE_SIZE / w * h)
196
+ w = EVA_IMAGE_SIZE
197
+ else:
198
+ w = int(EVA_IMAGE_SIZE / h * w)
199
+ h = EVA_IMAGE_SIZE
200
+
201
+ image[idx] = im.resize((w, h))
202
+
203
+ return image if len(image) > 1 else image[0]
204
+
205
+ def draw_box(self):
206
+ left, top, right, bottom = self.coordinate
207
+ mask = np.zeros((EVA_IMAGE_SIZE, EVA_IMAGE_SIZE, 3), dtype=np.uint8)
208
+ mask = cv2.rectangle(mask, (left, top), (right, bottom), (255, 255, 255), 3)
209
+ self.mask = Image.fromarray(mask)
210
+
211
+ def stack(self):
212
+ w, h = self.frames[0].size
213
+ n = len(self.frames)
214
+ stack_frame = Image.new(mode="RGB", size=(w*n, h))
215
+ for idx, f in enumerate(self.frames):
216
+ stack_frame.paste(f, (idx*w, 0))
217
+ self.stack_frame = stack_frame
218
+
219
+
220
class ConvMeta:
    """Ordered multi-turn conversation: a list of (Role, DataMeta) pairs plus
    a per-round log id used to correlate frontend and backend log lines."""

    def __init__(self):
        # System prompt prepended to every chat-formatted prompt.
        self.system: str = "You are a helpful assistant, dedicated to delivering comprehensive and meticulous responses."
        self.message: List[Tuple[Role, DataMeta]] = []
        self.log_id: str = gen_id()

        logging.info(f"{self.log_id}: create new round of chat")

    def append(self, r: Role, p: DataMeta):
        """Add one prompt element spoken by role ``r``."""
        logging.info(f"{self.log_id}: APPEND [{r.name}] prompt element, type: {p.datatype.name}, message: {p}")
        self.message.append((r, p))

    def format_chatbot(self):
        """Return (user, assistant) pair list for the gradio Chatbot widget;
        each turn fills one side and leaves the other None."""
        ret = []
        for r, p in self.message:
            cur_p = p.format_chatbot()
            if r == Role.USER:
                ret.append((cur_p, None))
            else:
                ret.append((None, cur_p))
        return ret

    def format_prompt(self):
        """Flatten the conversation into the mixed text/image sequence for the
        generation endpoint; prepend GRD_SYMBOL when any element is grounded."""
        ret = []
        has_coor = False
        for _, p in self.message:
            has_coor |= (p.datatype == DataType.GROUNDING)
            ret += p.format_prompt()

        if has_coor:
            ret.insert(0, GRD_SYMBOL)

        logging.info(f"{self.log_id}: format generation prompt: {ret}")
        return ret

    def format_chat(self):
        """Flatten the conversation into the chat-endpoint format: system
        prompt, then role markers inserted at every speaker change.
        Consecutive elements from the same speaker share one marker."""
        ret = [self.system]

        prev_r = None
        for r, p in self.message:
            if prev_r != r:
                # Role changed: close the previous speaker's span and open a
                # new one (EOS after an assistant turn, then the role token).
                if prev_r == Role.ASSISTANT:
                    ret.append(f"{DEFAULT_EOS_TOKEN}{USER_TOKEN}: ")
                elif prev_r is None:
                    ret.append(f" {USER_TOKEN}: ")
                else:
                    ret.append(f" {ASSISTANT_TOKEN}: ")
                ret += p.format_prompt()
                prev_r = r
            else:
                ret += p.format_prompt()

        # Trailing assistant marker cues the model to produce the reply.
        ret.append(f" {ASSISTANT_TOKEN}:")

        logging.info(f"{self.log_id}: format chat prompt: {ret}")
        return ret

    def clear(self):
        """Drop all turns and start a new round with a fresh log id."""
        logging.info(f"{self.log_id}: clear chat history, end current chat round.")
        del self.message
        self.message = []
        self.log_id = gen_id()

    def pop(self):
        """Remove the trailing assistant result (if any) so it can be redone."""
        if self.has_gen:
            logging.info(f"{self.log_id}: pop out previous generation / chat result")
            self.message.pop()

    def pop_error(self):
        """Drop every element previously marked as an error."""
        self.message = [(r, p) for r, p in self.message if p.datatype != DataType.ERROR]

    @property
    def has_gen(self):
        # True when the last turn belongs to the assistant, i.e. the round
        # already produced a generation/chat result.
        if len(self.message) == 0:
            return False
        if self.message[-1][0] == Role.USER:
            return False
        return True
demo/utils.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # ===================================================
4
+ #
5
+ # Author : Fan Zhang
6
+ # Email : zhangfan@baai.ac.cn
7
+ # Institute : Beijing Academy of Artificial Intelligence (BAAI)
8
+ # Create On : 2023-12-13 09:48
9
+ # Last Modified : 2023-12-14 01:53
10
+ # File Name : utils.py
11
+ # Description :
12
+ #
13
+ # ===================================================
14
+
15
+ from datetime import datetime
16
+ import logging
17
+ import logging.config
18
+ import hashlib
19
+ import os.path as osp
20
+ import uuid
21
+ from PIL import Image
22
+
23
+ from decord import VideoReader
24
+
25
+
26
def config_logger(logger_name):
    """Create a named logger that sends INFO+ records both to the console and
    to a weekly-rotated file at ``<repo>/log/<logger_name>.log`` (90 backups
    kept)."""
    log_file = osp.join(osp.dirname(__file__), "..", "log", f"{logger_name}.log")

    standard_fmt = {
        "format": "%(asctime)s - %(filename)s: %(lineno)d - [%(levelname)s] - %(message)s",
        "datefmt": "%Y-%m-%d %H:%M:%S",
    }
    console_handler = {
        "class": "logging.StreamHandler",
        "formatter": "standard",
        "level": "INFO",
    }
    file_handler = {
        "class": "logging.handlers.TimedRotatingFileHandler",
        "filename": log_file,
        "formatter": "standard",
        "level": "INFO",
        "when": "D",
        "interval": 7,
        "backupCount": 90,
    }

    logging.config.dictConfig({
        "version": 1,
        "formatters": {"standard": standard_fmt},
        "handlers": {"console": console_handler, "file": file_handler},
        "loggers": {
            logger_name: {
                "handlers": ["file", "console"],
                "level": "INFO",
                "propagate": True,
            },
        },
    })
    return logging.getLogger(logger_name)
63
+
64
# Shared module-level loggers: one for the gradio frontend, one for the
# serving backend.
frontend_logger = config_logger("Emu-v2_frontend")
# NOTE(review): "beckend" is a typo for "backend", but the name may be
# imported by other modules, so it is kept for compatibility.
beckend_logger = config_logger("Emu-v2_backend")
66
+
67
+
68
def extract_frames(video, num_frames):
    """Decode ``video`` and return up to ``num_frames`` evenly spaced frames
    as PIL images (sampled at the midpoint of each segment).

    Fixes: for videos shorter than ``num_frames`` the original computed a
    zero step and ``range`` raised ValueError; it could also return more
    than ``num_frames`` frames when the division truncated.
    """
    reader = VideoReader(video)
    total_frames = len(reader)
    # A step of at least 1 frame; short videos are sampled frame by frame.
    segment = max(int(total_frames // num_frames), 1)

    indices = list(range(int(segment // 2), total_frames, segment))[:int(num_frames)]
    frames = reader.get_batch(indices).asnumpy()
    return [Image.fromarray(f) for f in frames]
76
+
77
+
78
def image2md5(image: Image.Image):
    """Return the hex MD5 digest of the image's raw pixel bytes."""
    return hashlib.md5(image.tobytes()).hexdigest()
81
+
82
+
83
def gen_id():
    """Return a unique log id: a 14-digit timestamp plus a uuid4 hex suffix.

    Fixes the strftime format: the original ``"%Y%m%d%H%M%d"`` ended with a
    second ``%d`` (day of month) instead of ``%S`` (seconds), so ids from the
    same minute shared their full timestamp prefix.
    """
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    return f"{timestamp}{uuid.uuid4().hex}"
87
+
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio
2
+ pillow
3
+ numpy
4
+ opencv-python
5
+ decord