import gradio as gr

from demo import VideoCLSModel

# Sample egocentric videos from the Charades-Ego dataset.
sample_videos = [
    'data/charades_ego/video/P9SOAEGO.mp4',
    'data/charades_ego/video/6D5DHEGO.mp4',
    'data/charades_ego/video/15AKPEGO.mp4',
    'data/charades_ego/video/X2JTKEGO.mp4',
    'data/charades_ego/video/184EHEGO.mp4',
    'data/charades_ego/video/S8YZIEGO.mp4',
    'data/charades_ego/video/PRODQEGO.mp4',
    'data/charades_ego/video/QLXEXEGO.mp4',
    'data/charades_ego/video/CC0LBEGO.mp4',
    'data/charades_ego/video/FLY2FEGO.mp4',
]


def main():
    # Load the SViTT-Ego action-classification model once at startup.
    svitt = VideoCLSModel("configs/charades_ego/svitt.yml")

    def predict(video_str):
        # Gradio passes a file path; match its filename back to the sample
        # list to recover the dataset index the model expects.
        video_file = video_str.split('/')[-1]
        idx = next(
            (i for i, item in enumerate(sample_videos) if video_file in item),
            None,
        )
        if idx is None:
            raise gr.Error(f"Unknown video: {video_file}. Please pick one of the examples.")
        ft_action, gt_action = svitt.predict(idx)
        # Outputs are ordered to match (label, ours) below.
        return gt_action, ft_action

    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # SViTT-Ego for Action Recognition
            Choose a sample video and click Predict to view the results.
            """
        )
        with gr.Row():
            video = gr.Video(label="video", format="mp4", autoplay=True, height=256, width=256)
        with gr.Row():
            label = gr.Text(label="Ground Truth")
            ours = gr.Text(label="SViTT-Ego prediction")
        with gr.Row():
            btn = gr.Button("Predict", variant="primary")
            btn.click(predict, inputs=[video], outputs=[label, ours])
        with gr.Column():
            gr.Examples(examples=[[x] for x in sample_videos], inputs=[video])

    demo.launch()


if __name__ == "__main__":
    main()
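
# Usage (a minimal sketch, assuming this file is saved as app.py and that the
# local `demo.VideoCLSModel`, the config file, and the Charades-Ego sample
# videos listed above are all available on disk):
#
#   python app.py
#
# Gradio serves the interface at http://127.0.0.1:7860 by default; calling
# demo.launch(share=True) instead would also create a temporary public link.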