import gradio as gr
from demo import VideoCLSModel
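
# Gradio demo for SViTT-Ego on an egocentric multiple-choice question (MCQ) task:
# for each text query, the model picks the matching clip out of five candidate videos.
# sample_videos below holds four groups of five candidate clips, one group per query.
# Assumed interface (inferred from its use in main()): demo.VideoCLSModel loads a
# checkpoint from the given config file and exposes predict(idx, text), which returns
# the predicted and ground-truth video indices for sample group idx.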
sample_videos = [
    [
        "data/svitt-ego-demo/0/video/2d560d56-dc47-4c76-8d41-889c8aa55d66-converted.mp4",
        "data/svitt-ego-demo/0/video/eb5cb2b0-59e6-45da-af1b-ba86c7ab0b54-converted.mp4",
        "data/svitt-ego-demo/0/video/0a3097fc-baed-4d11-a4c9-30f07eb91af6-converted.mp4",
        "data/svitt-ego-demo/0/video/1a870d5d-5787-4098-ad8d-fe7343c43698-converted.mp4",
        "data/svitt-ego-demo/0/video/014b473f-aec0-49c7-b394-abc7309ca3c7-converted.mp4",
    ],
    [
        "data/svitt-ego-demo/1/video/029eeb9a-8853-48a4-a1dc-e8868b58adf3-converted.mp4",
        "data/svitt-ego-demo/1/video/968139e2-987e-4615-a2d4-fa2e683bae8a-converted.mp4",
        "data/svitt-ego-demo/1/video/fb9fda68-f264-465d-9208-19876f5ef90f-converted.mp4",
        "data/svitt-ego-demo/1/video/53da674a-089d-428a-a719-e322b2de002b-converted.mp4",
        "data/svitt-ego-demo/1/video/060e07d8-e818-4f9c-9d6b-6504f5fd42a3-converted.mp4",
    ],
    [
        "data/svitt-ego-demo/2/video/fa2f1291-3796-41a6-8f7b-6e7c1491b9b2-converted.mp4",
        "data/svitt-ego-demo/2/video/8d83478f-c5d2-4ac3-a823-e1b2ac7594d7-converted.mp4",
        "data/svitt-ego-demo/2/video/5f6f87ea-e1c3-4868-bb60-22c9e874d056-converted.mp4",
        "data/svitt-ego-demo/2/video/77718528-2de9-48b4-b6b8-e7c602032afb-converted.mp4",
        "data/svitt-ego-demo/2/video/9abbf7f4-68f0-4f52-812f-df2a3df48f7b-converted.mp4",
    ],
    [
        "data/svitt-ego-demo/3/video/2a6b3d10-8da9-4f0e-a681-59ba48a55dbf-converted.mp4",
        "data/svitt-ego-demo/3/video/5afd7421-fb6b-4c65-a09a-716f79a7a935-converted.mp4",
        "data/svitt-ego-demo/3/video/f7aec252-bd4f-4696-8de5-ef7b871e2194-converted.mp4",
        "data/svitt-ego-demo/3/video/84d6855a-242b-44a6-b48d-2db302b5ea7a-converted.mp4",
        "data/svitt-ego-demo/3/video/81fff27c-97c0-483a-ad42-47fa947977a9-converted.mp4",
    ],
]

sample_text = [
    "drops the palm fronds on the ground",
    "stands up",
    "throws nuts in a bowl",
    "puts the speaker and notepad in both hands on a seat",
]
sample_text_dict = {
    "drops the palm fronds on the ground": 0,
    "stands up": 1,
    "throws nuts in a bowl": 2,
    "puts the speaker and notepad in both hands on a seat": 3,
}
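
# Each group has five candidate clips; label them "video-0" ... "video-4" for display.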
num_samples = len(sample_videos[0])
labels = [f"video-{i}" for i in range(num_samples)]


def main():
    # Load SViTT-Ego with the EgoMCQ config and the sample video groups defined above.
    svitt = VideoCLSModel(
        "configs/ego_mcq/svitt.yml",
        sample_videos,
    )

    def predict(text):
        # Map the query to its sample group, run the model, and return the
        # ground-truth and predicted video labels.
        idx = sample_text_dict[text]
        ft_action, gt_action = svitt.predict(idx, text)
        return labels[gt_action], labels[ft_action]

    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # SViTT-Ego for Multiple-Choice Questions
            Choose a sample query and click Predict to view the results.
            """
        )
        # One video player per candidate clip in the selected sample group.
        with gr.Row():
            videos = [
                gr.Video(label=labels[i], format='mp4', height=256, width=256, autoplay=True)
                for i in range(num_samples)
            ]
        # The query textbox is hidden; it is populated by the Examples component below.
        with gr.Row():
            text = gr.Text(label="Query", visible=False)
            label = gr.Text(label="Ground Truth")
            ours = gr.Text(label="SViTT-Ego prediction")
        btn = gr.Button("Predict", variant="primary")
        btn.click(predict, inputs=[text], outputs=[label, ours])

        # Each example row pairs a query with its five candidate clips.
        inputs = [text]
        inputs.extend(videos)
        gr.Examples(
            examples=[
                [sample_text[i], x[0], x[1], x[2], x[3], x[4]]
                for i, x in enumerate(sample_videos)
            ],
            inputs=inputs,
        )

    demo.launch()


if __name__ == "__main__":
    main()