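"""Gradio demo for SViTT-Ego action recognition on Charades-Ego sample clips.

Pick one of the bundled example videos and click Predict to compare the
model's predicted action with the ground-truth label.
"""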
import gradio as gr
from demo import VideoCLSModel

sample_videos = [
    'data/charades_ego/video/P9SOAEGO.mp4',
    'data/charades_ego/video/6D5DHEGO.mp4',
    'data/charades_ego/video/15AKPEGO.mp4',
    'data/charades_ego/video/X2JTKEGO.mp4',
    'data/charades_ego/video/184EHEGO.mp4',
    'data/charades_ego/video/S8YZIEGO.mp4',
    'data/charades_ego/video/PRODQEGO.mp4',
    'data/charades_ego/video/QLXEXEGO.mp4',
    'data/charades_ego/video/CC0LBEGO.mp4',
    'data/charades_ego/video/FLY2FEGO.mp4',
]


def main():
    # Load the SViTT-Ego classifier with the Charades-Ego demo config.
    svitt = VideoCLSModel("configs/charades_ego/svitt.yml")

    def predict(video_str):
        # Gradio passes the selected clip's file path; map it back to its
        # index in sample_videos, since the model predicts by sample index.
        video_file = video_str.split('/')[-1]
        idx = None
        for i, item in enumerate(sample_videos):
            if video_file in item:
                idx = i
                break
        if idx is None:
            raise gr.Error("Please choose one of the sample videos.")
        ft_action, gt_action = svitt.predict(idx)
        return gt_action, ft_action

    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # SViTT-Ego for Action Recognition
            Choose a sample video and click Predict to view the results.
            """
        )
        with gr.Row():
            idx = gr.Number(label="Idx", visible=False)  # hidden placeholder, not wired to any event
            video = gr.Video(label='video', format='mp4', autoplay=True, height=256, width=256)
        with gr.Row():
            label = gr.Text(label="Ground Truth")
            ours = gr.Text(label="SViTT-Ego prediction")
        with gr.Row():
            btn = gr.Button("Predict", variant="primary")
            btn.click(predict, inputs=[video], outputs=[label, ours])
        with gr.Column():
            # Clickable examples that populate the video input above.
            gr.Examples(examples=[[x] for x in sample_videos], inputs=[video])

    demo.launch()
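
# A typical local run (the filename below is illustrative; use this script's
# actual name). Gradio serves the demo at http://127.0.0.1:7860 by default:
#
#   python app.py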

if __name__ == "__main__":
    main()