akhaliq HF staff committed on
Commit
8323f05
1 Parent(s): 25aa87e

Create new file

Browse files
Files changed (1) hide show
  1. app.py +53 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from decord import VideoReader, cpu
2
+ import torch
3
+ import numpy as np
4
+
5
+ from transformers import VideoMAEFeatureExtractor, VideoMAEForVideoClassification
6
+ from huggingface_hub import hf_hub_download
7
+ import gradio as gr
8
+
9
+ np.random.seed(0)
10
+
11
+
12
def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
    """Choose ``clip_len`` frame indices from a video of ``seg_len`` frames.

    A window of ``int(clip_len * frame_sample_rate)`` consecutive frames is
    placed at a random position inside the video (using NumPy's global
    random state) and ``clip_len`` evenly spaced indices are taken from it.

    When the video is not longer than that window, a random placement is
    impossible, so the indices are instead spread evenly over the whole
    video (frames repeat if the video has fewer than ``clip_len`` frames).
    This fallback is deterministic and does not consume random numbers.

    Args:
        clip_len: Number of frame indices to return.
        frame_sample_rate: Approximate stride, in frames, between samples.
        seg_len: Total number of frames available in the video.

    Returns:
        A 1-D ``np.int64`` array of ``clip_len`` non-decreasing indices,
        each within ``[0, seg_len - 1]``.

    Raises:
        ValueError: If the video has no frames to sample from.
    """
    converted_len = int(clip_len * frame_sample_rate)

    if converted_len < seg_len:
        # Enough frames: drop the sampling window at a random offset.
        end_idx = np.random.randint(converted_len, seg_len)
        start_idx = end_idx - converted_len
    else:
        if seg_len <= 0:
            raise ValueError(
                f"cannot sample frames from a video with seg_len={seg_len}"
            )
        # Short video: np.random.randint(low, high) would raise because
        # low >= high, so cover the entire video instead.
        start_idx, end_idx = 0, seg_len

    indices = np.linspace(start_idx, end_idx, num=clip_len)
    # linspace includes end_idx itself, which is one past the last frame of
    # the window; clip it back inside before truncating to integers.
    return np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
19
+
20
+
21
_VIDEOMAE_CHECKPOINT = "MCG-NJU/videomae-base-finetuned-kinetics"

# Filled on first use so the checkpoint is instantiated once per process
# instead of on every request.
_videomae_cache = {}


def _load_videomae():
    """Return the cached ``(feature_extractor, model)`` pair, loading it once."""
    if "model" not in _videomae_cache:
        # Load both into locals first so a failure cannot leave the cache
        # half-populated.
        feature_extractor = VideoMAEFeatureExtractor.from_pretrained(_VIDEOMAE_CHECKPOINT)
        model = VideoMAEForVideoClassification.from_pretrained(_VIDEOMAE_CHECKPOINT)
        _videomae_cache["feature_extractor"] = feature_extractor
        _videomae_cache["model"] = model
    return _videomae_cache["feature_extractor"], _videomae_cache["model"]


def inference(file_path):
    """Classify the action shown in a video file.

    Args:
        file_path: Path to the video file to classify.

    Returns:
        The label string from ``model.config.id2label`` for the class with
        the highest logit.

    Raises:
        ValueError: If ``file_path`` is empty or ``None`` (no video given).
    """
    if not file_path:
        raise ValueError("No video provided; upload a video before running.")

    videoreader = VideoReader(file_path, num_threads=1, ctx=cpu(0))

    # sample 16 frames
    videoreader.seek(0)
    indices = sample_frame_indices(clip_len=16, frame_sample_rate=4, seg_len=len(videoreader))
    video = videoreader.get_batch(indices).asnumpy()

    feature_extractor, model = _load_videomae()

    # The feature extractor is given the clip as a list of per-frame arrays.
    inputs = feature_extractor(list(video), return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # model predicts one of the 400 Kinetics-400 classes
    predicted_label = logits.argmax(-1).item()
    return model.config.id2label[predicted_label]
42
+
43
# Two-column layout: video upload and "Run" button on the left, the
# predicted label on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            video = gr.Video()
            btn = gr.Button(value="Run")
        with gr.Column():
            label = gr.Textbox(label="Predicted Label")

    # The button is bound to `btn`; the previous `translate_btn` name was
    # never defined and raised NameError while the UI was being built.
    btn.click(inference, inputs=video, outputs=label)

demo.launch()