SeViLA / app.py
Shoubin's picture
Update app.py
ff1c51e
import gradio as gr
import os
import torch
from torchvision import transforms
from lavis.processors import transforms_video
from lavis.datasets.data_utils import load_video_demo
from lavis.processors.blip_processors import ToUint8, ToTHWC
from lavis.models.sevila_models.sevila import SeViLA
from typing import Optional
import warnings
# ---------------------------------------------------------------------------
# Model configuration: every value below is forwarded to the SeViLA
# constructor further down.
# ---------------------------------------------------------------------------
img_size = 224
num_query_token = 32
t5_model = 'google/flan-t5-xl'
drop_path_rate = 0
use_grad_checkpoint = False
vit_precision = "fp16"
freeze_vit = True
prompt = ''
max_txt_len = 77
answer_num = 5
apply_lemmatizer = False
task = 'freeze_loc_freeze_qa_vid'

# ---------------------------------------------------------------------------
# Prompt suffixes appended to the question text by sevila_demo: one for the
# localization pass, one for the question-answering pass.
# ---------------------------------------------------------------------------
LOC_propmpt = 'Does the information within the frame provide the necessary details to accurately answer the given question?'
QA_prompt = 'Considering the information presented in the frame, select the correct answer from the options.'

# ---------------------------------------------------------------------------
# Video pre-processing: per-channel normalization statistics and the
# transform pipeline applied to every loaded clip.
# ---------------------------------------------------------------------------
mean = (0.48145466, 0.4578275, 0.40821073)
std = (0.26862954, 0.26130258, 0.27577711)
normalize = transforms.Normalize(mean=mean, std=std)
image_size = img_size  # frames are loaded at the model's input resolution
transform = transforms.Compose(
    [
        ToUint8(),
        ToTHWC(),
        transforms_video.ToTensorVideo(),
        normalize,
    ]
)

# ---------------------------------------------------------------------------
# Build the model once at import time and load the pretrained checkpoint.
# ---------------------------------------------------------------------------
print('model loading')
sevila = SeViLA(
    img_size=img_size,
    drop_path_rate=drop_path_rate,
    use_grad_checkpoint=use_grad_checkpoint,
    vit_precision=vit_precision,
    freeze_vit=freeze_vit,
    num_query_token=num_query_token,
    t5_model=t5_model,
    prompt=prompt,
    max_txt_len=max_txt_len,
    apply_lemmatizer=apply_lemmatizer,
    frame_num=4,
    answer_num=answer_num,
    task=task,
)
sevila.load_checkpoint(
    url_or_filename='https://huggingface.co/Shoubin/SeViLA/resolve/main/sevila_pretrained.pth'
)
print('model loaded')

# Answer index -> option letter (0 -> 'A', ..., 4 -> 'E').
ANS_MAPPING = dict(enumerate('ABCDE'))
def _count_or_default(value, default):
    """Convert a UI count field to int, using *default* when it is blank."""
    if value is None:
        return default
    if isinstance(value, str):
        value = value.strip()
        if not value:
            return default
    return int(value)


def _ensure_period(option):
    """Return *option* terminated by '.', tolerating empty or missing text."""
    option = option or ''
    return option if option.endswith('.') else option + '.'


def sevila_demo(video,
                question,
                option1, option2, option3,
                video_frame_num,
                keyframe_num):
    """Locate keyframes in a video and answer a three-option question.

    Args:
        video: Path of the video file; forwarded to ``load_video_demo``.
        question: Question text inserted into both prompts.
        option1, option2, option3: Candidate answers. A trailing '.' is
            appended to each one that lacks it (empty text is tolerated).
        video_frame_num: Number of frames to sample uniformly from the
            video. Anything ``int()`` accepts; blank falls back to 32,
            the value shown as placeholder in the UI.
        keyframe_num: Number of keyframes requested from the model.
            Blank falls back to 4, the UI placeholder.

    Returns:
        A ``(keyframes, timestamps_des, answer)`` tuple: the selected frames
        as numpy arrays, a multi-line string with one
        ``'Keyframe i: <t>s'`` entry per selected frame, and the text of the
        chosen option.

    Raises:
        ValueError: if a frame count is non-blank but not an integer.
    """
    global sevila

    n_frames = _count_or_default(video_frame_num, 32)
    n_keyframes = _count_or_default(keyframe_num, 4)

    # The model is moved to GPU 0 when CUDA is available; otherwise it is
    # cast to float32 and stays where it is.
    use_gpu = torch.cuda.is_available()
    if use_gpu:
        sevila = sevila.to(0)
    else:
        sevila = sevila.float()

    raw_clip, indice, fps, vlen = load_video_demo(
        video_path=video,
        n_frms=n_frames,
        height=image_size,
        width=image_size,
        sampling="uniform",
        clip_proposal=None,
    )

    # Swap the first two axes before the transform pipeline, then add a
    # leading batch dimension.
    clip = transform(raw_clip.permute(1, 0, 2, 3)).float()
    if use_gpu:
        clip = clip.to(0)
    clip = clip.unsqueeze(0)

    option1 = _ensure_period(option1)
    option2 = _ensure_period(option2)
    option3 = _ensure_period(option3)
    option_dict = {0: option1, 1: option2, 2: option3}

    options = 'Option A:{} Option B:{} Option C:{}'.format(option1, option2, option3)
    text_input_qa = 'Question: ' + question + ' ' + options + ' ' + QA_prompt
    text_input_loc = 'Question: ' + question + ' ' + options + ' ' + LOC_propmpt

    out = sevila.generate_demo(clip, text_input_qa, text_input_loc, n_keyframes)

    # NOTE(review): the model is built with answer_num = 5 while only three
    # options are offered here; presumably the predicted id stays in 0-2,
    # otherwise this lookup raises KeyError. Confirm against generate_demo.
    answer_id = out['output_text'][0]
    answer = option_dict[answer_id]
    select_index = out['frame_idx'][0]

    keyframes = []
    timestamps = []
    video_len = vlen / fps  # seconds
    for i in select_index:
        # Frames are indexed along dim 1 of raw_clip; move the leading axis
        # last before converting to numpy (assumes channel-first frames).
        frame = raw_clip[:, i, :, :].int()
        keyframes.append(frame.permute(1, 2, 0).numpy())
        # indice[i] is the position of the sampled frame in the source video.
        seconds = round((indice[i] / vlen) * video_len, 2)
        timestamps.append(str(seconds) + 's')

    timestamps_des = ''.join(
        'Keyframe {}: {} \n'.format(str(pos + 1), stamp)
        for pos, stamp in enumerate(timestamps)
    )
    return keyframes, timestamps_des, answer
# Gradio UI: inputs on the left, results on the right, examples underneath.
# NOTE(review): the source arrived with its indentation stripped, so the
# nesting of rows/columns below is reconstructed - confirm the layout.
with gr.Blocks(title="SeViLA demo") as demo:
    description = """<p style="text-align: center; font-weight: bold;">
<span style="font-size: 28px">Self-Chained Image-Language Model for Video Localization and Question Answering</span>
<br>
<span style="font-size: 18px" id="author-info">
<a href="https://yui010206.github.io/" target="_blank">Shoubin Yu</a>,
<a href="https://j-min.io/" target="_blank">Jaemin Cho</a>,
<a href="https://prateek-yadav.github.io/" target="_blank">Prateek Yadav</a>,
<a href="https://www.cs.unc.edu/~mbansal/" target="_blank">Mohit Bansal</a>
</span>
<br>
<span style="font-size: 18px" id="paper-info">
[<a href="https://github.com/Yui010206/SeViLA" target="_blank">GitHub</a>]
[<a href="https://arxiv.org/abs/2305.06988" target="_blank">Paper</a>]
</span>
</p>
<p>
To locate keyframes in a video and answer question, please:
<br>
(1) upload your video; (2) write your question/options and set # video frame/# keyframe; (3) click Locate and Answer!
<br>
Just a heads up - loading the SeViLA model can take a few minutes (typically 2-3), and running examples requires about 12GB of memory.
<br>
We've got you covered! We've provided some example videos and questions below to help you get started. Feel free to try out SeViLA with these!
</p>
"""
    gr.HTML(description)

    with gr.Row():
        # Left column: video, question, options and frame counts.
        with gr.Column(scale=1, min_width=600):
            video = gr.Video(label='Video')
            question = gr.Textbox(placeholder="Why did the two ladies put their hands above their eyes while staring out?", label='Question')
            with gr.Row():
                option1 = gr.Textbox(placeholder="practicing cheer", label='Option 1')
                option2 = gr.Textbox(placeholder="posing for photo", label='Option 2')
                option3 = gr.Textbox(placeholder="to see better", label='Option 3')
            with gr.Row():
                # Placeholders are hint text, so they are passed as strings.
                video_frame_num = gr.Textbox(placeholder='32', label='# Video Frame')
                keyframe_num = gr.Textbox(placeholder='4', label='# Keyframe')
            gen_btn = gr.Button(value='Locate and Answer!')

        # Right column: selected keyframes, their timestamps and the answer.
        with gr.Column(scale=1, min_width=600):
            keyframes = gr.Gallery(
                label="Keyframes", show_label=False, elem_id="gallery",
            ).style(columns=[4], rows=[1], object_fit="contain", max_width=100, max_height=100)
            # Plain gr.Textbox, matching the input boxes, instead of the
            # deprecated gr.outputs namespace.
            timestamps = gr.Textbox(label="Keyframe Timestamps")
            answer = gr.Textbox(label="Output Answer")

    gen_btn.click(
        sevila_demo,
        inputs=[video, question, option1, option2, option3, video_frame_num, keyframe_num],
        outputs=[keyframes, timestamps, answer],
        queue=True
    )

    # Clickable examples that pre-fill the inputs; outputs are not cached,
    # so the model only runs when the button is pressed.
    with gr.Column():
        gr.Examples(
            inputs=[video, question, option1, option2, option3, video_frame_num, keyframe_num],
            outputs=[keyframes, timestamps, answer],
            fn=sevila_demo,
            examples=[
                ['videos/demo1.mp4', 'Why did the two ladies put their hands above their eyes while staring out?', 'practicing cheer', 'to place wreaths', 'to see better', 32, 4],
                ['videos/demo2.mp4', 'What did both of them do after completing skiing?', 'jump and pose', 'bend down', 'raised their hands', 32, 4],
                ['videos/demo3.mp4', 'What room was Wilson breaking into when House found him?', 'the bedroom', 'the bathroom', 'the kitchen', 32, 4],
                ['videos/demo4.mp4', 'What kind of bird is it?', 'chikadee', 'eagle', 'sparrow', 32, 1],
            ],
            cache_examples=False,
        )

# One request at a time; the queue's API endpoints are not exposed.
demo.queue(concurrency_count=1, api_open=False)
demo.launch(share=False)