Spaces:
Runtime error
Runtime error
import os | |
import torch | |
import gradio as gr | |
from transformers import XCLIPProcessor, XCLIPModel | |
from utils import convert_frames_to_gif, download_youtube_video, sample_frames_from_video_file | |
# Zero-shot X-CLIP checkpoint. The same processor is used in `predict` for both
# the text labels and the sampled video frames.
model_name = "microsoft/xclip-base-patch16-zero-shot"
processor = XCLIPProcessor.from_pretrained(model_name)
model = XCLIPModel.from_pretrained(model_name)

# (youtube_url, comma-separated labels) pairs fed to `gr.Examples`.
# All links use the same canonical https://youtu.be/<id> short-link form.
examples = [
    ["https://youtu.be/l1dBM8ZECao", "sleeping dog,cat fight club,birds of prey"],
    ["https://youtu.be/VMj-3S1tku0", "programming course,eating spaghetti,playing football"],
    ["https://youtu.be/x8UAUAuKNcU", "game of thrones,the lord of the rings,vikings"],
]
def predict(youtube_url, labels_text):
    """Score a YouTube video against user-supplied labels with X-CLIP.

    Args:
        youtube_url: URL of the YouTube video to classify.
        labels_text: Comma-separated candidate labels, e.g. "dog, cat, bird".

    Returns:
        A ``(label_to_prob, gif_path)`` tuple: a dict mapping each label to its
        softmax probability, and the path produced by ``convert_frames_to_gif``
        from the sampled frames (shown as the "Input Clip" preview).

    Raises:
        gr.Error: if ``labels_text`` contains no non-empty label.
    """
    # Strip whitespace around each label ("dog, cat" -> ["dog", "cat"]), drop
    # empty entries from stray/trailing commas, and de-duplicate while keeping
    # order: duplicates would split the softmax mass yet collapse to a single
    # key in the returned dict.
    labels = list(dict.fromkeys(
        label.strip() for label in labels_text.split(",") if label.strip()
    ))
    if not labels:
        # Fail fast, before the (slow) video download, when there is nothing to score.
        raise gr.Error("Please provide at least one label.")

    video_path = download_youtube_video(youtube_url)
    try:
        frames = sample_frames_from_video_file(video_path, num_frames=32)
    finally:
        # Remove the downloaded video even if frame sampling fails.
        os.remove(video_path)
    gif_path = convert_frames_to_gif(frames)

    inputs = processor(
        text=labels,
        videos=list(frames),
        return_tensors="pt",
        padding=True,
    )
    # Inference only: gradients are not needed.
    with torch.no_grad():
        outputs = model(**inputs)
    # Row 0 holds the scores of the first (here, only) video against every
    # label; softmax over the labels turns those scores into probabilities.
    probs = outputs.logits_per_video[0].softmax(dim=-1).cpu().numpy()
    label_to_prob = {label: float(prob) for label, prob in zip(labels, probs)}
    return label_to_prob, gif_path
# Gradio UI: URL + labels inputs, a GIF preview of the sampled frames, and a
# label/probability output, all wired to `predict`.
with gr.Blocks() as app:
    gr.Markdown("# **<p align='center'>Zero-shot Video Classification with X-CLIP</p>**")
    with gr.Row():
        with gr.Column():
            gr.Markdown("Provide a Youtube video URL and a list of labels separated by commas")
            youtube_url = gr.Textbox(label="Youtube URL:", show_label=True)
            labels_text = gr.Textbox(label="Labels Text:", show_label=True)
            predict_btn = gr.Button(value="Predict")
        with gr.Column():
            video_gif = gr.Image(label="Input Clip", show_label=True)
        with gr.Column():
            predictions = gr.Label(label='Predictions:', show_label=True)
    gr.Markdown("**Examples:**")
    # cache_examples=True would call `predict` on every example at startup,
    # downloading each YouTube video; a single failed download then aborts the
    # whole app. Run examples on demand instead.
    gr.Examples(
        examples,
        [youtube_url, labels_text],
        [predictions, video_gif],
        fn=predict,
        cache_examples=False,
    )
    predict_btn.click(predict, inputs=[youtube_url, labels_text], outputs=[predictions, video_gif])
    gr.Markdown(
        """
        \n Demo created by: <a href=\"https://github.com/fcakyon\">fcakyon</a>
        <br> Based on this <a href=\"https://huggingface.co/microsoft/xclip-base-patch16-zero-shot\">HuggingFace model</a>
        """
    )

# Launch only when run as a script (as Spaces does), not on import.
if __name__ == "__main__":
    app.launch()