"""Gradio demo: zero-shot video classification of YouTube clips with X-CLIP.

The script loads the X-CLIP processor and model once at import time, builds a
small Blocks UI (URL + comma-separated labels in, label probabilities + a
preview image out) and launches it.
"""

import os

import torch
import gradio as gr
from transformers import XCLIPProcessor, XCLIPModel

from utils import (
    convert_frames_to_gif,
    download_youtube_video,
    sample_frames_from_video_file,
)

model_name = "microsoft/xclip-base-patch16-zero-shot"
processor = XCLIPProcessor.from_pretrained(model_name)
model = XCLIPModel.from_pretrained(model_name)

# Each example row is [youtube_url, comma-separated candidate labels].
examples = [
    ["https://www.youtu.be/l1dBM8ZECao", "sleeping dog,cat fight club,birds of prey"],
    ["https://youtu.be/VMj-3S1tku0", "programming course,eating spaghetti,playing football"],
    ["https://www.youtu.be/x8UAUAuKNcU", "game of thrones,the lord of the rings,vikings"],
]

# Footer text is kept exactly as it appears in the original script.
FOOTER_MARKDOWN = """ \n Demo created by: fcakyon
Based on this HuggingFace model """


def _parse_labels(labels_text):
    """Split comma-separated text into unique, non-empty, stripped labels."""
    stripped = (label.strip() for label in (labels_text or "").split(","))
    # dict.fromkeys drops duplicates while keeping first-seen order; duplicates
    # would otherwise collapse in the result dict yet still share softmax mass.
    return list(dict.fromkeys(label for label in stripped if label))


def predict(youtube_url, labels_text):
    """Score a YouTube video against user-supplied candidate labels.

    Args:
        youtube_url: URL handed to ``download_youtube_video``.
        labels_text: Candidate labels separated by commas. Surrounding
            whitespace is ignored; empty and duplicate entries are dropped.

    Returns:
        A ``(label_to_prob, gif_path)`` tuple: a dict mapping each label to
        its softmax probability for the video, and the value returned by
        ``convert_frames_to_gif`` (shown in the image output of the UI).

    Raises:
        ValueError: If ``labels_text`` contains no usable label.
    """
    labels = _parse_labels(labels_text)
    if not labels:
        # Fail before downloading anything.
        raise ValueError("Provide at least one non-empty label.")

    video_path = download_youtube_video(youtube_url)
    try:
        frames = sample_frames_from_video_file(video_path, num_frames=32)
    finally:
        # The downloaded file is only needed for sampling; remove it even
        # when sampling fails so failed requests do not leave files behind.
        os.remove(video_path)
    gif_path = convert_frames_to_gif(frames)

    inputs = processor(
        text=labels,
        videos=list(frames),
        return_tensors="pt",
        padding=True,
    )

    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # One video was passed in, so row 0 holds its score against every label.
    probs = outputs.logits_per_video[0].softmax(dim=-1).tolist()
    label_to_prob = dict(zip(labels, probs))

    return label_to_prob, gif_path


with gr.Blocks() as app:
    # Single-line heading: the bold markers must sit on the heading line
    # itself for the title to render.
    gr.Markdown("# **Zero-shot Video Classification with X-CLIP**")

    with gr.Row():
        with gr.Column():
            gr.Markdown("Provide a Youtube video URL and a list of labels separated by commas")
            youtube_url = gr.Textbox(label="Youtube URL:", show_label=True)
            labels_text = gr.Textbox(label="Labels Text:", show_label=True)
            predict_btn = gr.Button(value="Predict")

        with gr.Column():
            video_gif = gr.Image(label="Input Clip", show_label=True)

    with gr.Column():
        predictions = gr.Label(label='Predictions:', show_label=True)

    gr.Markdown("**Examples:**")
    # cache_examples=True makes Gradio run `predict` on the examples up front.
    gr.Examples(
        examples,
        [youtube_url, labels_text],
        [predictions, video_gif],
        fn=predict,
        cache_examples=True,
    )

    predict_btn.click(
        predict,
        inputs=[youtube_url, labels_text],
        outputs=[predictions, video_gif],
    )

    gr.Markdown(FOOTER_MARKDOWN)

app.launch()