|
import torch |
|
import gradio as gr |
|
from transformers import AutoProcessor, AutoModel |
|
from utils import ( |
|
convert_frames_to_gif, |
|
download_youtube_video, |
|
get_num_total_frames, |
|
sample_frames_from_video_file, |
|
) |
|
|
|
# Default stride (in frames) between consecutive sampled frames; predict()
# shrinks it for videos too short to supply enough frames at this stride.
FRAME_SAMPLING_RATE = 4
# Checkpoint loaded at startup and preselected in the model dropdown.
DEFAULT_MODEL = "microsoft/xclip-base-patch16-zero-shot"

# Hugging Face Hub checkpoints offered in the model dropdown. Each entry is
# passed verbatim to `from_pretrained`, so it must be a valid "<org>/<repo>" id
# (exactly one slash).
VALID_ZEROSHOT_VIDEOCLASSIFICATION_MODELS = [
    "microsoft/xclip-base-patch32",
    "microsoft/xclip-base-patch16-zero-shot",
    "microsoft/xclip-base-patch16-kinetics-600",
    # Previously the malformed id
    # "microsoft/xclip-large-patch14ft/xclip-base-patch32-16-frames".
    "microsoft/xclip-base-patch32-16-frames",
    "microsoft/xclip-large-patch14",
    "microsoft/xclip-base-patch16-hmdb-4-shot",
    "microsoft/xclip-base-patch16-16-frames",
    "microsoft/xclip-base-patch16-hmdb-2-shot",
    "microsoft/xclip-base-patch16-ucf-2-shot",
    "microsoft/xclip-base-patch16-ucf-8-shot",
    "microsoft/xclip-base-patch16",
    "microsoft/xclip-base-patch16-hmdb-8-shot",
    "microsoft/xclip-base-patch16-hmdb-16-shot",
    "microsoft/xclip-base-patch16-ucf-16-shot",
]
|
|
|
# Load the default checkpoint at import time so requests can be served right
# away; select_model() rebinds these two module globals later on.
processor, model = (
    AutoProcessor.from_pretrained(DEFAULT_MODEL),
    AutoModel.from_pretrained(DEFAULT_MODEL),
)
|
|
|
# (video URL, comma-separated candidate labels) rows shown under the inputs
# via gr.Examples (cached, see cache_examples=True in the UI definition).
examples = [
    [
        # Was "https://www.youtu.be/..."; youtu.be short links have no "www."
        # host, and the other rows use the plain form.
        "https://youtu.be/l1dBM8ZECao",
        "sleeping dog,cat fight club,birds of prey",
    ],
    [
        "https://youtu.be/VMj-3S1tku0",
        "programming course,eating spaghetti,playing football",
    ],
    [
        "https://youtu.be/Tm6BlRMEny0",
        "game of thrones,the lord of the rings,vikings",
    ],
]
|
|
|
|
|
def select_model(model_name):
    """Replace the module-level ``processor``/``model`` pair with ``model_name``.

    Both artifacts are loaded into locals first and published together only
    after both loads succeed. If either load raises (invalid repo id, network
    failure), the previous matching pair stays in place instead of a new
    processor being paired with the old model.

    NOTE(review): these are module globals, so the switch affects every
    session of the app, not just the user who changed the dropdown.

    Args:
        model_name: Hugging Face Hub id of the checkpoint to load.
    """
    global processor, model

    new_processor = AutoProcessor.from_pretrained(model_name)
    new_model = AutoModel.from_pretrained(model_name)
    processor, model = new_processor, new_model
|
|
|
|
|
def _resolve_video_path(youtube_url_or_file_path):
    """Return a local video path, downloading the video first when given a URL."""
    source = (youtube_url_or_file_path or "").strip()
    if not source:
        # e.g. the "Local File" tab submitted without an uploaded video.
        raise gr.Error("Please provide a Youtube URL or upload a video file.")
    if source.startswith("http"):
        return download_youtube_video(source)
    return youtube_url_or_file_path


def _parse_labels(labels_text):
    """Split comma-separated text into unique, stripped, non-empty labels (order kept)."""
    stripped = (label.strip() for label in (labels_text or "").split(","))
    # dict.fromkeys de-duplicates while preserving order: a repeated label would
    # otherwise be scored twice and overwrite itself in the result dict.
    labels = list(dict.fromkeys(label for label in stripped if label))
    if not labels:
        raise gr.Error("Please provide at least one label.")
    return labels


def _compute_frame_sampling_rate(num_total_frames, num_model_input_frames):
    """Choose a frame stride so ``num_model_input_frames`` frames fit in the video.

    Uses FRAME_SAMPLING_RATE when the video is long enough. Otherwise the stride
    is shrunk to fit, but never below 1: the original formula produced 0 for
    videos with fewer frames than the model needs, a degenerate stride for
    sample_frames_from_video_file.
    """
    if num_total_frames < FRAME_SAMPLING_RATE * num_model_input_frames:
        return max(1, num_total_frames // num_model_input_frames)
    return FRAME_SAMPLING_RATE


def predict(youtube_url_or_file_path, labels_text):
    """Score a video against comma-separated free-text labels with the current model.

    Args:
        youtube_url_or_file_path: A URL (any string starting with "http") that
            is downloaded via download_youtube_video, or a local video path.
        labels_text: Candidate labels separated by commas. Surrounding
            whitespace is stripped; empty and duplicate entries are ignored.

    Returns:
        ``(label_to_prob, gif_path)``: a dict mapping each label to its softmax
        probability, and the path of a GIF built from the sampled frames.

    Raises:
        gr.Error: If no video source or no non-empty label was provided.
    """
    # Snapshot the globals once: select_model() may rebind them from another
    # event while this request runs, and the processor must match the model.
    current_processor, current_model = processor, model

    # Validate labels before a potentially slow download.
    labels = _parse_labels(labels_text)
    video_path = _resolve_video_path(youtube_url_or_file_path)

    num_model_input_frames = current_model.config.vision_config.num_frames
    frame_sampling_rate = _compute_frame_sampling_rate(
        get_num_total_frames(video_path), num_model_input_frames
    )

    frames = sample_frames_from_video_file(
        video_path, num_model_input_frames, frame_sampling_rate
    )
    gif_path = convert_frames_to_gif(frames, save_path="video.gif")

    inputs = current_processor(
        text=labels, videos=list(frames), return_tensors="pt", padding=True
    )
    with torch.no_grad():
        outputs = current_model(**inputs)

    # Row 0 holds the scores for our single video, one per label (presumably
    # logits_per_video is (num_videos, num_labels) — consistent with the
    # per-label indexing below).
    probs = outputs.logits_per_video[0].softmax(dim=-1).cpu().numpy()
    label_to_prob = {label: float(prob) for label, prob in zip(labels, probs)}

    return label_to_prob, gif_path
|
|
|
|
|
with gr.Blocks() as app:
    # Page title and author links.
    gr.Markdown(
        "# **<p align='center'>Zero-shot Video Classification with 🤗 Transformers</p>**"
    )
    gr.Markdown(
        """
        <p style='text-align: center'>
        Follow me for more!
        <br> <a href='https://twitter.com/fcakyon' target='_blank'>twitter</a> | <a href='https://github.com/fcakyon' target='_blank'>github</a> | <a href='https://www.linkedin.com/in/fcakyon/' target='_blank'>linkedin</a> | <a href='https://fcakyon.medium.com/' target='_blank'>medium</a>
        </p>
        """
    )

    # One row of three columns: inputs | sampled-clip preview | predictions.
    with gr.Row():
        with gr.Column():
            model_names_dropdown = gr.Dropdown(
                choices=VALID_ZEROSHOT_VIDEOCLASSIFICATION_MODELS,
                label="Model:",
                show_label=True,
                value=DEFAULT_MODEL,
            )
            # Picking another checkpoint reloads the shared processor/model.
            model_names_dropdown.change(fn=select_model, inputs=model_names_dropdown)

            with gr.Tab(label="Youtube URL"):
                gr.Markdown(
                    "### **Provide a Youtube video URL and a list of labels separated by commas**"
                )
                youtube_url = gr.Textbox(label="Youtube URL:", show_label=True)
                youtube_url_labels_text = gr.Textbox(label="Labels Text:", show_label=True)
                youtube_url_predict_btn = gr.Button(value="Predict")

            with gr.Tab(label="Local File"):
                gr.Markdown(
                    "### **Upload a video file and provide a list of labels separated by commas**"
                )
                video_file = gr.Video(label="Video File:", show_label=True)
                local_video_labels_text = gr.Textbox(label="Labels Text:", show_label=True)
                local_video_predict_btn = gr.Button(value="Predict")

        with gr.Column():
            video_gif = gr.Image(label="Input Clip", show_label=True)

        with gr.Column():
            predictions = gr.Label(label="Predictions:", show_label=True)

    # Example rows feed the Youtube URL tab; their outputs are cached.
    gr.Markdown("**Examples:**")
    gr.Examples(
        examples=examples,
        inputs=[youtube_url, youtube_url_labels_text],
        outputs=[predictions, video_gif],
        fn=predict,
        cache_examples=True,
    )

    # Both tabs run the same predict() and fill the same output components.
    youtube_url_predict_btn.click(
        fn=predict,
        inputs=[youtube_url, youtube_url_labels_text],
        outputs=[predictions, video_gif],
    )
    local_video_predict_btn.click(
        fn=predict,
        inputs=[video_file, local_video_labels_text],
        outputs=[predictions, video_gif],
    )

    # Footer credits.
    gr.Markdown(
        """
        \n Demo created by: <a href=\"https://github.com/fcakyon\">fcakyon</a>.
        <br> Based on this <a href=\"https://huggingface.co/docs/transformers/main/model_doc/xclip">HuggingFace model</a>.
        """
    )

app.launch()
|
|