Spaces:
Runtime error
Runtime error
Update demo.
Browse files
app.py
CHANGED
@@ -102,6 +102,7 @@ class Chat:
|
|
102 |
# 2. text preprocess (tag process & generate prompt).
|
103 |
state = self.get_prompt(prompt, state)
|
104 |
prompt = state.get_prompt()
|
|
|
105 |
input_ids = tokenizer_MMODAL_token(prompt, tokenizer, MMODAL_TOKEN_INDEX[modals[0]], return_tensors='pt')
|
106 |
input_ids = input_ids.unsqueeze(0).to(self.model.device)
|
107 |
|
@@ -130,15 +131,13 @@ class Chat:
|
|
130 |
|
131 |
|
132 |
@spaces.GPU(duration=120)
|
133 |
-
def generate(image, video,
|
134 |
-
flag = 1
|
135 |
if not textbox_in:
|
136 |
if len(state_.messages) > 0:
|
137 |
textbox_in = state_.messages[-1][1]
|
138 |
state_.messages.pop(-1)
|
139 |
-
flag = 0
|
140 |
else:
|
141 |
-
|
142 |
|
143 |
image = image if image else "none"
|
144 |
video = video if video else "none"
|
@@ -187,30 +186,34 @@ def generate(image, video, first_run, state, state_, textbox_in, temperature, to
|
|
187 |
if os.path.exists(video):
|
188 |
show_images += f'<video controls playsinline width="500" style="display: inline-block;" src="./file={video}"></video>'
|
189 |
|
190 |
-
|
191 |
-
state.append_message(state.roles[0], textbox_in + "\n" + show_images)
|
192 |
state.append_message(state.roles[1], textbox_out)
|
193 |
|
194 |
-
|
195 |
-
|
|
|
|
|
|
|
|
|
|
|
196 |
|
197 |
|
198 |
-
def regenerate(state, state_
|
|
|
199 |
state.messages.pop(-1)
|
200 |
-
state_.messages.pop(-1)
|
201 |
-
textbox = gr.update(value=None, interactive=True)
|
202 |
if len(state.messages) > 0:
|
203 |
-
return state
|
204 |
-
return state
|
205 |
|
206 |
|
207 |
def clear_history(state, state_):
|
208 |
state = conv_templates[conv_mode].copy()
|
209 |
state_ = conv_templates[conv_mode].copy()
|
210 |
return (gr.update(value=None, interactive=True),
|
211 |
-
gr.update(value=None, interactive=True),
|
212 |
-
state.to_gradio_chatbot(),
|
213 |
-
|
|
|
214 |
|
215 |
# BUG of Zero Environment
|
216 |
# 1. The environment is fixed to torch==2.0.1+cu117, gradio>=4.x.x
|
@@ -230,7 +233,6 @@ with gr.Blocks(title='VideoLLaMA 2 π₯ππ₯', theme=gr.themes.Default(primar
|
|
230 |
gr.Markdown(title_markdown)
|
231 |
state = gr.State()
|
232 |
state_ = gr.State()
|
233 |
-
first_run = gr.State()
|
234 |
|
235 |
with gr.Row():
|
236 |
with gr.Column(scale=3):
|
@@ -331,20 +333,20 @@ with gr.Blocks(title='VideoLLaMA 2 π₯ππ₯', theme=gr.themes.Default(primar
|
|
331 |
|
332 |
submit_btn.click(
|
333 |
generate,
|
334 |
-
[image, video,
|
335 |
-
[image, video, chatbot,
|
336 |
|
337 |
regenerate_btn.click(
|
338 |
regenerate,
|
339 |
-
[state, state_
|
340 |
-
[state, state_
|
341 |
generate,
|
342 |
-
[image, video,
|
343 |
-
[image, video, chatbot,
|
344 |
|
345 |
clear_btn.click(
|
346 |
clear_history,
|
347 |
[state, state_],
|
348 |
-
[image, video, chatbot,
|
349 |
|
350 |
demo.launch()
|
|
|
102 |
# 2. text preprocess (tag process & generate prompt).
|
103 |
state = self.get_prompt(prompt, state)
|
104 |
prompt = state.get_prompt()
|
105 |
+
|
106 |
input_ids = tokenizer_MMODAL_token(prompt, tokenizer, MMODAL_TOKEN_INDEX[modals[0]], return_tensors='pt')
|
107 |
input_ids = input_ids.unsqueeze(0).to(self.model.device)
|
108 |
|
|
|
131 |
|
132 |
|
133 |
@spaces.GPU(duration=120)
|
134 |
+
def generate(image, video, state, state_, textbox_in, temperature, top_p, max_output_tokens, dtype=torch.float16):
|
|
|
135 |
if not textbox_in:
|
136 |
if len(state_.messages) > 0:
|
137 |
textbox_in = state_.messages[-1][1]
|
138 |
state_.messages.pop(-1)
|
|
|
139 |
else:
|
140 |
+
assert "Please enter instruction"
|
141 |
|
142 |
image = image if image else "none"
|
143 |
video = video if video else "none"
|
|
|
186 |
if os.path.exists(video):
|
187 |
show_images += f'<video controls playsinline width="500" style="display: inline-block;" src="./file={video}"></video>'
|
188 |
|
189 |
+
state.append_message(state.roles[0], textbox_in + "\n" + show_images)
|
|
|
190 |
state.append_message(state.roles[1], textbox_out)
|
191 |
|
192 |
+
# BUG: only support single turn conversation now.
|
193 |
+
state_.messages.pop(-1)
|
194 |
+
state_.messages.pop(-1)
|
195 |
+
|
196 |
+
return (gr.update(value=image if os.path.exists(image) else None, interactive=True),
|
197 |
+
gr.update(value=video if os.path.exists(video) else None, interactive=True),
|
198 |
+
state.to_gradio_chatbot(), state, state_)
|
199 |
|
200 |
|
201 |
+
def regenerate(state, state_):
|
202 |
+
state.messages.pop(-1)
|
203 |
state.messages.pop(-1)
|
|
|
|
|
204 |
if len(state.messages) > 0:
|
205 |
+
return state.to_gradio_chatbot(), state, state_
|
206 |
+
return state.to_gradio_chatbot(), state, state_
|
207 |
|
208 |
|
209 |
def clear_history(state, state_):
|
210 |
state = conv_templates[conv_mode].copy()
|
211 |
state_ = conv_templates[conv_mode].copy()
|
212 |
return (gr.update(value=None, interactive=True),
|
213 |
+
gr.update(value=None, interactive=True),
|
214 |
+
state.to_gradio_chatbot(), state, state_,
|
215 |
+
gr.update(value=None, interactive=True))
|
216 |
+
|
217 |
|
218 |
# BUG of Zero Environment
|
219 |
# 1. The environment is fixed to torch==2.0.1+cu117, gradio>=4.x.x
|
|
|
233 |
gr.Markdown(title_markdown)
|
234 |
state = gr.State()
|
235 |
state_ = gr.State()
|
|
|
236 |
|
237 |
with gr.Row():
|
238 |
with gr.Column(scale=3):
|
|
|
333 |
|
334 |
submit_btn.click(
|
335 |
generate,
|
336 |
+
[image, video, state, state_, textbox, temperature, top_p, max_output_tokens],
|
337 |
+
[image, video, chatbot, state, state_])
|
338 |
|
339 |
regenerate_btn.click(
|
340 |
regenerate,
|
341 |
+
[state, state_],
|
342 |
+
[chatbot, state, state_]).then(
|
343 |
generate,
|
344 |
+
[image, video, state, state_, textbox, temperature, top_p, max_output_tokens],
|
345 |
+
[image, video, chatbot, state, state_])
|
346 |
|
347 |
clear_btn.click(
|
348 |
clear_history,
|
349 |
[state, state_],
|
350 |
+
[image, video, chatbot, state, state_, textbox])
|
351 |
|
352 |
demo.launch()
|