# File size: 5,904 Bytes
# Commits: eb0678a 6a6199b eb0678a cc32390 eb0678a cc32390 eb0678a 6274907 eb0678a 6993849 eb0678a
# (Web-viewer line-number gutter, 1-172, omitted.)
import torch
import gradio as gr
from flash_vstream.serve.demo import Chat, title_markdown, block_css
from flash_vstream.constants import *
from flash_vstream.conversation import conv_templates, Conversation
import os
from PIL import Image
import tempfile
import imageio
import shutil
# Model identifier handed to Chat() when the handler is built further down.
model_path = "IVGSZ/Flash-VStream-7b"

# Quantised-loading switches forwarded to Chat(); both are disabled here.
load_8bit = load_4bit = False
def save_image_to_local(image):
    """Re-save *image* as a JPEG in ``temp/`` under a fresh random name.

    Args:
        image: Anything ``PIL.Image.open`` accepts (a path or a file object).

    Returns:
        The relative path ``temp/<random>.jpg`` of the written file.
    """
    # Local import: uuid is only needed here and is not imported at file level.
    import uuid

    # Do not rely on some other part of the program having created the folder.
    os.makedirs('temp', exist_ok=True)
    # uuid4 replaces the private, undocumented tempfile._get_candidate_names().
    filename = os.path.join('temp', uuid.uuid4().hex + '.jpg')

    # The context manager closes the source file once the copy is written.
    with Image.open(image) as source:
        # JPEG cannot store alpha or palette modes; saving them raises OSError.
        if source.mode not in ('RGB', 'L', 'CMYK'):
            source = source.convert('RGB')
        source.save(filename)
    return filename
def save_video_to_local(video_path):
    """Copy *video_path* into ``temp/`` under a fresh random ``.mp4`` name.

    Args:
        video_path: Path of the existing video file to copy.

    Returns:
        The relative path ``temp/<random>.mp4`` of the copy.
    """
    # Local import: uuid is only needed here and is not imported at file level.
    import uuid

    # Do not rely on some other part of the program having created the folder.
    os.makedirs('temp', exist_ok=True)
    # uuid4 replaces the private, undocumented tempfile._get_candidate_names().
    filename = os.path.join('temp', uuid.uuid4().hex + '.mp4')
    shutil.copyfile(video_path, filename)
    return filename
def generate(video, textbox_in, first_run, state, state_, images_tensor):
    """Run one chat turn and return the values for the Gradio outputs.

    Args:
        video: Path of the input video, or a falsy value when none is set.
        textbox_in: Instruction text. When empty, the last message stored in
            ``state_`` is removed and replayed instead.
        first_run: Ignored on input; recomputed from ``state``.
        state: Conversation shown in the chatbot, or a non-Conversation
            placeholder before the first turn.
        state_: Conversation passed to ``handler.generate``.
        images_tensor: Tensor kept from an earlier turn; it is passed to the
            model unchanged when no video file is available this turn.

    Returns:
        ``(state, state_, chatbot_rows, False, textbox_update,
        images_tensor, video_update)``; both updates clear their component.

    Raises:
        gr.Error: If there is no instruction and no stored message to replay.
    """
    # Build the conversations before touching them: prior to the first turn
    # the state values are not Conversation objects, so reading
    # state_.messages would fail.
    if not isinstance(state, Conversation):
        state = conv_templates[conv_mode].copy()
        state_ = conv_templates[conv_mode].copy()
        images_tensor = []

    append_user_turn = True
    if not textbox_in:
        if not state_.messages:
            # Returning a bare string here would not match the seven outputs
            # wired to this callback; raise a user-visible error instead.
            raise gr.Error("Please enter instruction")
        # Replay the most recent stored message; it is already present in the
        # displayed conversation, so it must not be appended again below.
        textbox_in = state_.messages.pop()[1]
        append_user_turn = False

    # Test the path directly instead of substituting a "none" sentinel, which
    # would match a real file of that name in the working directory.
    has_video = bool(video) and os.path.exists(video)

    first_run = not state.messages
    text_en_in = textbox_in.replace("picture", "image")

    if has_video:
        image_processor = handler.image_processor
        video_tensor = handler._get_rawvideo_dec(video, image_processor, max_frames=MAX_IMAGE_LENGTH)
        images_tensor = image_processor(video_tensor, return_tensors='pt')['pixel_values'].to(handler.model.device, dtype=torch.float16)
        print("video_tensor", video_tensor.shape)
        text_en_in = DEFAULT_IMAGE_TOKEN + '\n' + text_en_in

    text_en_out, state_ = handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_)
    state_.messages[-1] = (state_.roles[1], text_en_out)

    # Only the text before the first '#' is shown in the chatbot.
    textbox_out = text_en_out.split('#')[0]

    show_images = ""
    if has_video:
        filename = save_video_to_local(video)
        show_images += f'<video controls playsinline width="500" style="display: inline-block;" src="./file={filename}"></video>'

    if append_user_turn:
        state.append_message(state.roles[0], textbox_in + "\n" + show_images)
    state.append_message(state.roles[1], textbox_out)

    return (state, state_, state.to_gradio_chatbot(), False, gr.update(value=None, interactive=True), images_tensor, gr.update(value=None, interactive=True))
def regenerate(state, state_):
    """Drop the latest message from both conversations.

    Args:
        state: Conversation shown in the chatbot (``None`` before any turn).
        state_: Conversation passed to the model (``None`` before any turn).

    Returns:
        ``(state, state_, chatbot_rows, first_run)`` where ``first_run`` is
        True when the displayed conversation is empty afterwards.
    """
    # The button can be pressed before any turn has run, when the state
    # values are still None; start from empty conversations in that case.
    if state is None or state_ is None:
        state = conv_templates[conv_mode].copy()
        state_ = conv_templates[conv_mode].copy()

    # Popping from an empty history would raise IndexError, so guard it.
    for conversation in (state, state_):
        if conversation.messages:
            conversation.messages.pop()

    return state, state_, state.to_gradio_chatbot(), not state.messages
def clear_history(state, state_):
    """Discard both conversations and reset the input widgets.

    The incoming ``state`` and ``state_`` are ignored; fresh copies of the
    ``conv_mode`` template replace them.

    Returns:
        ``(video_update, textbox_update, True, state, state_, chatbot_rows,
        [])``; both updates clear their component and leave it interactive.
    """
    fresh_display = conv_templates[conv_mode].copy()
    fresh_model = conv_templates[conv_mode].copy()

    reset_video = gr.update(value=None, interactive=True)
    reset_textbox = gr.update(value=None, interactive=True)

    return (
        reset_video,
        reset_textbox,
        True,
        fresh_display,
        fresh_model,
        fresh_display.to_gradio_chatbot(),
        [],
    )
# Key into conv_templates; also passed to Chat as its conversation mode.
conv_mode = "vicuna_v1"
handler = Chat(model_path, conv_mode=conv_mode, load_4bit=load_4bit, load_8bit=load_8bit)

# Folder that receives the copies written by the save_*_to_local helpers.
# exist_ok avoids the gap between checking for the folder and creating it.
os.makedirs("temp", exist_ok=True)

# Report current and peak CUDA memory after the handler has been built.
print(torch.cuda.memory_allocated())
print(torch.cuda.max_memory_allocated())
with gr.Blocks(title='Flash-VStream', theme=gr.themes.Soft(), css=block_css) as demo:
    gr.Markdown(title_markdown)

    # Values carried between callbacks; each starts out as None.
    state = gr.State()
    state_ = gr.State()
    first_run = gr.State()
    images_tensor = gr.State()

    # NOTE(review): the nesting of rows/columns below was reconstructed from a
    # listing whose indentation was lost; confirm it against the intended layout.
    with gr.Row():
        with gr.Column(scale=3):
            video = gr.Video(label="Input Video")

        with gr.Column(scale=7):
            # height is passed to the constructor; the .style() method is deprecated.
            chatbot = gr.Chatbot(label="Flash-VStream", bubble_full_width=True, height=700)
            with gr.Row():
                with gr.Column(scale=8):
                    textbox = gr.Textbox(show_label=False,
                                         placeholder="Enter text and press Send",
                                         container=False)
                with gr.Column(scale=2, min_width=50):
                    submit_btn = gr.Button(value="Send", variant="primary", interactive=True)
            with gr.Row(visible=True) as button_row:
                regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
                clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)

    # Example clips are looked up in an examples/ folder next to this file.
    cur_dir = os.path.dirname(os.path.abspath(__file__))
    with gr.Row():
        # One gr.Examples widget per clip, each filling the video and textbox.
        for clip_name, prompt in (
            ("video1.mp4", "Describe the video briefly."),
            ("video4.mp4", "What is the boy doing?"),
            ("video5.mp4", "Why is this video funny?"),
        ):
            gr.Examples(
                examples=[
                    [
                        f"{cur_dir}/examples/{clip_name}",
                        prompt,
                    ]
                ],
                inputs=[video, textbox],
            )

    # generate() is wired identically for Send and for the step after Regenerate.
    generate_inputs = [video, textbox, first_run, state, state_, images_tensor]
    generate_outputs = [state, state_, chatbot, first_run, textbox, images_tensor, video]

    submit_btn.click(generate, generate_inputs, generate_outputs)

    # Regenerate first drops the last message, then re-runs generate().
    regenerate_btn.click(regenerate, [state, state_], [state, state_, chatbot, first_run]).then(
        generate, generate_inputs, generate_outputs)

    clear_btn.click(clear_history, [state, state_],
                    [video, textbox, first_run, state, state_, chatbot, images_tensor])

# app = gr.mount_gradio_app(app, demo, path="/")
demo.launch()