adpai committed on
Commit
9303daf
1 Parent(s): c84be09

Update app.py

Files changed (1)
  1. app.py +72 -2
app.py CHANGED
@@ -1,3 +1,73 @@
- import gradio as gr
 
- gr.Interface.load("models/microsoft/xclip-base-patch32").launch()
+ import io
+ import av
+ import torch
+ import numpy as np
+ from fastapi import FastAPI, UploadFile, File
+ from transformers import AutoProcessor, AutoModel
+
+ app = FastAPI()
+
+ np.random.seed(0)
+
+ def read_video_pyav(container, indices):
+     '''
+     Decode the video with the PyAV decoder.
+     Args:
+         container (`av.container.input.InputContainer`): PyAV container.
+         indices (`List[int]`): List of frame indices to decode.
+     Returns:
+         result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
+     '''
+     frames = []
+     container.seek(0)
+     start_index = indices[0]
+     end_index = indices[-1]
+     for i, frame in enumerate(container.decode(video=0)):
+         if i > end_index:
+             break
+         if i >= start_index and i in indices:
+             frames.append(frame)
+     return np.stack([x.to_ndarray(format="rgb24") for x in frames])
+
+ def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
+     '''
+     Sample a given number of frame indices from the video.
+     Args:
+         clip_len (`int`): Total number of frames to sample.
+         frame_sample_rate (`int`): Sample every n-th frame.
+         seg_len (`int`): Maximum allowed index of the sample's last frame.
+     Returns:
+         indices (`List[int]`): List of sampled frame indices.
+     '''
+     converted_len = int(clip_len * frame_sample_rate)
+     end_idx = np.random.randint(converted_len, seg_len)
+     start_idx = end_idx - converted_len
+     indices = np.linspace(start_idx, end_idx, num=clip_len)
+     indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
+     return indices
+
+ processor = AutoProcessor.from_pretrained("microsoft/xclip-base-patch32")
+ model = AutoModel.from_pretrained("microsoft/xclip-base-patch32")
+
+ @app.post("/classify_video/")
+ async def classify_video(file: UploadFile = File(...)):
+     file_bytes = await file.read()
+
+     container = av.open(io.BytesIO(file_bytes))  # av.open needs a path or file-like object, not raw bytes
+     indices = sample_frame_indices(clip_len=8, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
+     video = read_video_pyav(container, indices)
+
+     inputs = processor(
+         text=["playing sports", "eating spaghetti", "go shopping"],
+         videos=[video],  # changed list(video) to [video] to avoid an error
+         return_tensors="pt",
+         padding=True,
+     )
+
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     logits_per_video = outputs.logits_per_video
+     probs = logits_per_video.softmax(dim=1)
+
+     return {"classification_probabilities": probs.tolist()}