|
import os

import gradio as gr

from demo import VideoCLSModel
|
|
|
# Charades-Ego clips offered as clickable examples in the demo. Order matters:
# the demo's predict() passes a clip's position in this list to
# VideoCLSModel.predict (presumably an index into the model's own sample set —
# confirm against demo.VideoCLSModel).
sample_videos = [
    f'data/charades_ego/video/{clip_id}.mp4'
    for clip_id in (
        'P9SOAEGO',
        '6D5DHEGO',
        '15AKPEGO',
        'X2JTKEGO',
        '184EHEGO',
        'S8YZIEGO',
        'PRODQEGO',
        'QLXEXEGO',
        'CC0LBEGO',
        'FLY2FEGO',
    )
]
|
|
|
def _find_sample_index(video_path, videos):
    """Return the position in *videos* of the clip whose file name matches *video_path*.

    Only the file name is compared: the path delivered by the video component
    presumably lives in a different directory than the entry in *videos*
    (e.g. a temporary copy) — confirm against the Gradio version in use.
    The comparison is exact, so an unrelated upload whose name merely occurs
    inside a sample path (e.g. ``EGO.mp4``) is not mistaken for that sample.

    Args:
        video_path: Path of the selected video; ``None`` or ``""`` when nothing
            is selected.
        videos: Candidate clip paths to search, in order.

    Returns:
        The index of the first candidate with the same file name, or ``None``
        when *video_path* is empty or matches no candidate.
    """
    if not video_path:
        return None
    # os.path.basename also honours "\\" separators on Windows, unlike split('/').
    wanted = os.path.basename(video_path)
    return next(
        (i for i, path in enumerate(videos) if os.path.basename(path) == wanted),
        None,
    )


def main():
    """Build and launch the SViTT-Ego action-recognition Gradio demo.

    Loads the model from ``configs/charades_ego/svitt.yml``. The user picks one
    of ``sample_videos``, and the UI shows that clip's ground-truth action next
    to the model's prediction.
    """
    svitt = VideoCLSModel("configs/charades_ego/svitt.yml")

    def predict(video_str):
        """Return ``(ground_truth, prediction)`` for the selected sample clip.

        Raises:
            gr.Error: if no video is selected or the video is not one of
                ``sample_videos``. The model is only queried by sample index,
                so there is nothing to predict for such a video.
        """
        idx = _find_sample_index(video_str, sample_videos)
        if idx is None:
            # Show a readable message in the UI instead of crashing with
            # AttributeError (no selection) or UnboundLocalError (no match).
            raise gr.Error("Please choose one of the sample videos.")

        # svitt.predict yields (prediction, ground truth); the UI lists the
        # ground truth first, matching outputs=[label, ours] below.
        ft_action, gt_action = svitt.predict(idx)
        return gt_action, ft_action

    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # SViTT-Ego for Action Recognition

            Choose a sample video and click predict to view the results.
            """
        )
        with gr.Row():
            video = gr.Video(label='video', format='mp4', autoplay=True, height=256, width=256)
        with gr.Row():
            label = gr.Text(label="Ground Truth")
            ours = gr.Text(label="SViTT-Ego prediction")
        with gr.Row():
            btn = gr.Button("Predict", variant="primary")
            btn.click(predict, inputs=[video], outputs=[label, ours])
        with gr.Column():
            # Each example row holds a single input: the clip path.
            gr.Examples(examples=[[path] for path in sample_videos], inputs=[video])

    demo.launch()
|
|
|
|
|
if __name__ == '__main__':
    # Build and launch the UI only when run directly, so importing this module
    # does not start the demo.
    main()
|
|
|
|