File size: 1,738 Bytes
c18a21e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import gradio as gr

from demo import VideoCLSModel

# Charades-Ego clips bundled with the demo; predict() resolves a selected
# video back to its index in this list before querying the model.
_VIDEO_DIR = 'data/charades_ego/video'
sample_videos = [
    f'{_VIDEO_DIR}/{basename}'
    for basename in (
        'P9SOAEGO.mp4',
        '6D5DHEGO.mp4',
        '15AKPEGO.mp4',
        'X2JTKEGO.mp4',
        '184EHEGO.mp4',
        'S8YZIEGO.mp4',
        'PRODQEGO.mp4',
        'QLXEXEGO.mp4',
        'CC0LBEGO.mp4',
        'FLY2FEGO.mp4',
    )
]

def main():
    """Build the SViTT-Ego Gradio demo UI and launch it.

    Loads the Charades-Ego fine-tuned model once, then serves an
    interface where the user picks a sample video and sees the ground
    truth action next to the model's prediction.
    """
    svitt = VideoCLSModel("configs/charades_ego/svitt.yml")

    def predict(video_str):
        """Resolve the chosen video to its sample index and run the model.

        Returns a (ground_truth, prediction) pair of strings for display.
        """
        # Gradio may hand back a copied/temp path, so match on the
        # basename against the known sample list rather than the full path.
        video_file = video_str.split('/')[-1]
        idx = next(
            (i for i, item in enumerate(sample_videos) if video_file in item),
            None,
        )
        if idx is None:
            # Original code raised UnboundLocalError here when no sample
            # matched; return a readable message instead of crashing.
            msg = "Unrecognized video; please choose one of the examples."
            return msg, msg

        ft_action, gt_action = svitt.predict(idx)
        return gt_action, ft_action

    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # SViTT-Ego for Action Recognition 
            Choose a sample video and click predict to view the results.
            """
        )
        with gr.Row():
            video = gr.Video(label='video', format='mp4', autoplay=True, height=256, width=256)
            with gr.Row():
                label = gr.Text(label="Ground Truth")
                ours = gr.Text(label="SViTT-Ego prediction")
        with gr.Row():
            btn = gr.Button("Predict", variant="primary")
        btn.click(predict, inputs=[video], outputs=[label, ours])
        with gr.Column():
            gr.Examples(examples=[[x] for x in sample_videos], inputs=[video])

    demo.launch()


# Script entry point: load the model and launch the Gradio demo server.
if __name__ == "__main__":
    main()