Spaces:
Runtime error
Runtime error
| import os | |
| import gradio as gr | |
| from utils import ( | |
| create_gif_from_video_file, | |
| download_youtube_video, | |
| get_num_total_frames, | |
| ) | |
| from transformers import pipeline | |
| from huggingface_hub import HfApi, ModelSearchArguments, ModelFilter | |
# Default stride (in frames) between the frames sampled from a video.
FRAME_SAMPLING_RATE = 4

# Checkpoint loaded at start-up and pre-selected in the model dropdown.
DEFAULT_MODEL = "facebook/timesformer-base-finetuned-k400"

# Checkpoints offered in the model dropdown, in display order.
VALID_VIDEOCLASSIFICATION_MODELS = [
    "MCG-NJU/videomae-large-finetuned-kinetics",
    "facebook/timesformer-base-finetuned-k400",
    "fcakyon/timesformer-large-finetuned-k400",
    "MCG-NJU/videomae-base-finetuned-kinetics",
    "facebook/timesformer-base-finetuned-k600",
    "fcakyon/timesformer-large-finetuned-k600",
    "facebook/timesformer-hr-finetuned-k400",
    "facebook/timesformer-hr-finetuned-k600",
    "facebook/timesformer-base-finetuned-ssv2",
    "fcakyon/timesformer-large-finetuned-ssv2",
    "facebook/timesformer-hr-finetuned-ssv2",
    "MCG-NJU/videomae-base-finetuned-ssv2",
    "MCG-NJU/videomae-base-short-finetuned-kinetics",
    "MCG-NJU/videomae-base-short-ssv2",
    "MCG-NJU/videomae-base-short-finetuned-ssv2",
    "sayakpaul/videomae-base-finetuned-ucf101-subset",
    "nateraw/videomae-base-finetuned-ucf101",
    "MCG-NJU/videomae-base-ssv2",
    "zahrav/videomae-base-finetuned-ucf101-subset",
]
# Module-level classifier shared by the UI callbacks; it is built once at
# import time for DEFAULT_MODEL and rebound by select_model() when the user
# picks another checkpoint.
pipe = pipeline(
    "video-classification",
    model=DEFAULT_MODEL,
    frame_sampling_rate=FRAME_SAMPLING_RATE,
    top_k=5,
)
# Example inputs for the demo: one single-column row (a YouTube watch URL)
# per video ID, in display order.
examples = [
    [f"https://www.youtube.com/watch?v={video_id}"]
    for video_id in (
        "huAJ9dC5lmI",
        "wvcWt6u5HTg",
        "-3kZSi5qjRM",
        "-6usjfP8hys",
        "BDHub0gBGtc",
        "B9ea7YyCP6E",
        "BBkpaeJBKmk",
        "BBqU8Apee_g",
        "B8OdMwVwyXc",
        "I7cwq6_4QtM",
        "Z0mJDXpNhYA",
        "QkQQjFGnZlg",
        "IQaoRUQif14",
    )
]
def get_video_model_names():
    """Return the ids of Hub models tagged for video classification.

    Queries the Hugging Face Hub for models whose pipeline tag is
    ``VideoClassification`` and whose library is ``Transformers``, requesting
    them sorted by ``"downloads"`` with ``direction=-1`` (presumably
    most-downloaded first -- confirm against the Hub API docs).

    Returns:
        list: The ``id`` attribute of every model returned by the Hub.
    """
    # NOTE(review): ModelSearchArguments and ModelFilter appear to have been
    # dropped from newer huggingface_hub releases -- confirm the pinned
    # version before relying on this helper.
    model_args = ModelSearchArguments()
    # Named ``model_filter`` so the builtin ``filter`` is not shadowed.
    model_filter = ModelFilter(
        task=model_args.pipeline_tag.VideoClassification,
        library=model_args.library.Transformers,
    )
    api = HfApi()
    # list_models() already yields an iterable, so collect the ids in one pass
    # instead of materialising an intermediate list of model objects.
    return [
        video_model.id
        for video_model in api.list_models(
            filter=model_filter, sort="downloads", direction=-1
        )
    ]
def select_model(model_name):
    """Swap the shared classifier for one built from ``model_name``.

    Rebinds the module-level ``pipe`` to a fresh video-classification
    pipeline; nothing is returned.

    Args:
        model_name: Model identifier passed as ``model`` to ``pipeline``.
    """
    global pipe
    pipeline_options = {
        "task": "video-classification",
        "model": model_name,
        "top_k": 5,
        "frame_sampling_rate": FRAME_SAMPLING_RATE,
    }
    pipe = pipeline(**pipeline_options)
def predict(youtube_url_or_file_path):
    """Classify a video and render a GIF preview of the sampled clip.

    Args:
        youtube_url_or_file_path: A string. Values starting with ``"http"``
            are downloaded with ``download_youtube_video``; anything else is
            used directly as a path to a video file.

    Returns:
        tuple: ``(label_to_score, gif_path)`` where ``label_to_score`` maps
        each predicted label to its score and ``gif_path`` is the value
        returned by ``create_gif_from_video_file`` (saved as ``"video.gif"``).

    Raises:
        gr.Error: If no URL or file path was supplied.

    Side effects:
        The video file at the resolved path is deleted once processing ends,
        whether it succeeded or failed.
    """
    # The upload component hands over None when nothing was selected and the
    # textbox an empty string; report that clearly instead of failing later.
    if not youtube_url_or_file_path:
        raise gr.Error("Please provide a YouTube URL or upload a video file.")

    if youtube_url_or_file_path.startswith("http"):
        video_path = download_youtube_video(youtube_url_or_file_path)
    else:
        video_path = youtube_url_or_file_path

    try:
        # Rearrange the sampling rate based on video length and model input
        # length: short clips get a smaller stride so enough frames exist.
        num_total_frames = get_num_total_frames(video_path)
        num_model_input_frames = pipe.model.config.num_frames
        if num_total_frames < FRAME_SAMPLING_RATE * num_model_input_frames:
            # Clamp to 1: the floor division yields 0 for clips shorter than
            # the model's input length, which is not a usable stride.
            frame_sampling_rate = max(
                1, num_total_frames // num_model_input_frames
            )
        else:
            frame_sampling_rate = FRAME_SAMPLING_RATE

        gif_path = create_gif_from_video_file(
            video_path,
            frame_sampling_rate=frame_sampling_rate,
            save_path="video.gif",
        )

        # Run inference.
        results = pipe(
            videos=video_path, frame_sampling_rate=frame_sampling_rate
        )
    finally:
        # Clean up the video even when GIF creation or inference fails, and
        # do not let a missing file mask the original exception.
        try:
            os.remove(video_path)
        except FileNotFoundError:
            pass

    label_to_score = {result["label"]: result["score"] for result in results}
    return label_to_score, gif_path
# Build the Gradio interface
with gr.Blocks() as app:
    # Page heading and short description
    gr.Markdown("# **<p align='center'>Video Classification with 🤗 Transformers</p>**")
    gr.Markdown(
        "\n"
        "<p style='text-align: center'>\n"
        "Perform video classification with <a href='https://huggingface.co/models?pipeline_tag=video-classification&library=transformers' target='_blank'>HuggingFace Transformers video models</a>.\n"
        "<br> For zero-shot classification, you can use the <a href='https://huggingface.co/spaces/fcakyon/zero-shot-video-classification' target='_blank'>zero-shot classification demo</a>.\n"
        "</p>\n"
    )

    # Model picker: changing the selection rebuilds the shared pipeline
    gr.Label("Model:")
    model_names_dropdown = gr.Dropdown(
        value=DEFAULT_MODEL,
        choices=VALID_VIDEOCLASSIFICATION_MODELS,
    )
    model_names_dropdown.change(fn=select_model, inputs=model_names_dropdown)

    # Two side-by-side input columns: a YouTube URL or an uploaded file
    with gr.Row():
        with gr.Column():
            gr.Markdown("### **Provide a Youtube video URL**")
            youtube_url = gr.Textbox(label="Youtube URL:")
            youtube_url_predict_btn = gr.Button(value="Predict")
        with gr.Column():
            gr.Markdown("### **Upload a video file**")
            video_file = gr.Video(label="Video File:")
            local_video_predict_btn = gr.Button(value="Predict")

    # Preview of the clip that was fed to the model
    video_gif = gr.Image(label="Input Clip")

    # Top predictions for the clip
    predictions = gr.Label(label="Predictions:", num_top_classes=5)

    # Example URLs; cache_examples=True is passed so their outputs are cached
    gr.Markdown("**Examples:**")
    gr.Examples(
        examples=examples,
        inputs=youtube_url,
        outputs=[predictions, video_gif],
        fn=predict,
        cache_examples=True,
    )

    # Both Predict buttons run the same callback on their own input
    youtube_url_predict_btn.click(
        fn=predict, inputs=youtube_url, outputs=[predictions, video_gif]
    )
    local_video_predict_btn.click(
        fn=predict, inputs=video_file, outputs=[predictions, video_gif]
    )

# Start serving the app
app.launch()