khaldii commited on
Commit
e6aa344
1 Parent(s): 1fd1426

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -0
app.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import av
3
+ import torch
4
+ # from transformers.models.auto import AutoImageProcessor, AutoModelForVideoClassification
5
+ from transformers import AutoImageProcessor, AutoModelForVideoClassification
6
+ import streamlit as st
7
+
8
+
9
+ def read_video_pyav(container, indices):
10
+ '''
11
+ Decode the video with PyAV decoder.
12
+ Args:
13
+ container (`av.container.input.InputContainer`): PyAV container.
14
+ indices (`List[int]`): List of frame indices to decode.
15
+ Returns:
16
+ result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
17
+ '''
18
+ frames = []
19
+ container.seek(0)
20
+ start_index = indices[0]
21
+ end_index = indices[-1]
22
+ for i, frame in enumerate(container.decode(video=0)):
23
+ if i > end_index:
24
+ break
25
+ if i >= start_index and i in indices:
26
+ frames.append(frame)
27
+ return np.stack([x.to_ndarray(format="rgb24") for x in frames])
28
+
29
+
30
+ def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
31
+ '''
32
+ Sample a given number of frame indices from the video.
33
+ Args:
34
+ clip_len (`int`): Total number of frames to sample.
35
+ frame_sample_rate (`int`): Sample every n-th frame.
36
+ seg_len (`int`): Maximum allowed index of sample's last frame.
37
+ Returns:
38
+ indices (`List[int]`): List of sampled frame indices
39
+ '''
40
+ converted_len = int(clip_len * frame_sample_rate)
41
+ end_idx = np.random.randint(converted_len, seg_len)
42
+ start_idx = end_idx - converted_len
43
+ indices = np.linspace(start_idx, end_idx, num=clip_len)
44
+ indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
45
+ return indices
46
+
47
+
48
+
49
+ def classify(file):
50
+ container = av.open(file)
51
+
52
+ # sample 16 frames
53
+ indices = sample_frame_indices(clip_len=16, frame_sample_rate=4, seg_len=container.streams.video[0].frames)
54
+ video = read_video_pyav(container, indices)
55
+
56
+ if container.streams.video[0].frames < 16:
57
+ return 'Video trop courte'
58
+
59
+ inputs = image_processor(list(video), return_tensors="pt")
60
+
61
+ with torch.no_grad():
62
+ outputs = model(**inputs)
63
+ logits = outputs.logits
64
+
65
+ # model predicts one of the 400 Kinetics-400 classes
66
+ predicted_label = logits.argmax(-1).item()
67
+ print(model.config.id2label[predicted_label])
68
+
69
+ return model.config.id2label[predicted_label]
70
+
71
+
72
+ model_ckpt = '2nzi/videomae-surf-analytics'
73
+ # pipe = pipeline("video-classification", model="2nzi/videomae-surf-analytics")
74
+ image_processor = AutoImageProcessor.from_pretrained(model_ckpt)
75
+ model = AutoModelForVideoClassification.from_pretrained(model_ckpt)
76
+
77
+
78
+
79
+
80
+
81
+ st.subheader("Surf Analytics")
82
+
83
+ st.markdown("""
84
+ Bienvenue sur le projet Surf Analytics réalisé par Walid, Guillaume, Valentine, et Antoine.
85
+
86
+ <a href="https://github.com/2nzi/M09-FinalProject-Surf-Analytics" style="text-decoration: none;">@Surf-Analytics-Github</a>.
87
+ """, unsafe_allow_html=True)
88
+
89
+ st.title("Surf Maneuver Classification")
90
+
91
+ uploaded_file = st.file_uploader("Upload a video file", type=["mp4", "avi", "mov"])
92
+
93
+ if uploaded_file is not None:
94
+ video_bytes = uploaded_file.read()
95
+ st.video(video_bytes)
96
+ predicted_label = classify(uploaded_file)
97
+ st.success(f"Predicted Label: {predicted_label}")