khaldii committed on
Commit
48652c4
1 Parent(s): 8b1c801

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -0
app.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import av
3
+ import torch
4
+ from transformers import AutoImageProcessor, AutoModelForVideoClassification
5
+ import streamlit as st
6
+ import torch.nn as nn
7
+
8
+
9
def read_video_pyav(container, indices):
    """Decode selected frames of a video with the PyAV decoder.

    Args:
        container (`av.container.input.InputContainer`): PyAV container.
        indices (`List[int]`): Frame indices to decode. The first and last
            entries bound the range of frame positions that is considered.

    Returns:
        result (np.ndarray): np array of decoded frames of shape
            (num_frames, height, width, 3).

    Raises:
        ValueError: If none of the requested frames could be decoded.
    """
    frames = []
    container.seek(0)
    first_index = indices[0]
    last_index = indices[-1]
    # Build the lookup set once; testing membership against the raw index
    # sequence would rescan it for every decoded frame.
    wanted = {int(index) for index in indices}

    for position, frame in enumerate(container.decode(video=0)):
        if position > last_index:
            break
        if position >= first_index and position in wanted:
            frames.append(frame.to_ndarray(format="rgb24"))

    if not frames:
        # np.stack([]) would fail with a message that hides the real cause.
        raise ValueError("no frames were decoded for the requested indices")
    return np.stack(frames)
28
+
29
+
30
def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
    """Sample a given number of frame indices from the video.

    A window of `clip_len * frame_sample_rate` frames is placed at a random
    position inside the video and `clip_len` evenly spaced indices are taken
    from it. When the video has no more frames than that window, the indices
    are spread evenly over the whole video instead (frames may then repeat).

    Args:
        clip_len (`int`): Total number of frames to sample.
        frame_sample_rate (`int`): Sample every n-th frame.
        seg_len (`int`): Maximum allowed index of sample's last frame.

    Returns:
        indices (`np.ndarray`): int64 array holding `clip_len` frame indices.

    Raises:
        ValueError: If `seg_len` is not positive.
    """
    if seg_len <= 0:
        raise ValueError(f"seg_len must be positive, got {seg_len}")

    converted_len = int(clip_len * frame_sample_rate)
    if seg_len <= converted_len:
        # np.random.randint(low, high) needs low < high, so a video this
        # short cannot host a random window; cover every available frame.
        return np.linspace(0, seg_len - 1, num=clip_len).astype(np.int64)

    end_idx = np.random.randint(converted_len, seg_len)
    start_idx = end_idx - converted_len
    indices = np.linspace(start_idx, end_idx, num=clip_len)
    # linspace includes end_idx itself; clip so the last index stays inside
    # the sampled window.
    return np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
46
+
47
+ # def sample_frame_indices2(clip_len, frame_sample_rate, seg_len):
48
+ # '''
49
+ # Description
50
+ # Args:
51
+ # Returns:
52
+ # indices (`List[int]`): List of sampled frame indices
53
+ # '''
54
+ # return
55
+
56
+
57
+
58
def _predict_label(model, inputs):
    """Run `model` on `inputs` and return the label with the highest logit."""
    with torch.no_grad():
        logits = model(**inputs).logits
    label = model.config.id2label[logits.argmax(-1).item()]
    print(label)
    return label


def classify(model_maneuver, model_Surf_notSurf, file):
    """Classify an uploaded video in two stages.

    Sixteen frames are sampled from the video and passed to
    `model_Surf_notSurf`. If its predicted label is not 'Surfing', that label
    is returned as is; otherwise the same frames are passed to
    `model_maneuver` and its predicted label is returned.

    Args:
        model_maneuver: Video classification model queried when the first
            stage predicts 'Surfing'.
        model_Surf_notSurf: Video classification model queried first.
        file: Anything `av.open` accepts (a path or a file-like object).

    Returns:
        The `id2label` entry of the predicted class.
    """
    # A file-like object may already have been read by the caller; rewind it
    # so decoding starts from the first byte.
    seekable = getattr(file, "seekable", None)
    if callable(seekable) and seekable():
        file.seek(0)

    # The context manager closes the container once the frames are decoded.
    with av.open(file) as container:
        # sample 16 frames
        indices = sample_frame_indices(
            clip_len=16,
            frame_sample_rate=4,
            seg_len=container.streams.video[0].frames,
        )
        video = read_video_pyav(container, indices)

    inputs = image_processor(list(video), return_tensors="pt")

    label = _predict_label(model_Surf_notSurf, inputs)
    if label != 'Surfing':
        return label
    return _predict_label(model_maneuver, inputs)
88
+
89
+
90
# Hugging Face Hub checkpoints: one model names the maneuver, the other is
# queried first to tell surfing clips from everything else.
MANEUVER_CHECKPOINT = '2nzi/videomae-surf-analytics'
SURF_FILTER_CHECKPOINT = '2nzi/videomae-surf-analytics-surfNOTsurf'

# Both models are fed through the processor loaded from the maneuver checkpoint.
image_processor = AutoImageProcessor.from_pretrained(MANEUVER_CHECKPOINT)
model_maneuver = AutoModelForVideoClassification.from_pretrained(MANEUVER_CHECKPOINT)
model_Surf_notSurf = AutoModelForVideoClassification.from_pretrained(SURF_FILTER_CHECKPOINT)


st.subheader("Surf Analytics")

st.markdown("""
Bienvenue sur le projet Surf Analytics réalisé par Walid, Guillaume, Valentine, et Antoine.

<a href="https://github.com/2nzi/M09-FinalProject-Surf-Analytics" style="text-decoration: none;">@Surf-Analytics-Github</a>.
""", unsafe_allow_html=True)

st.title("Surf Maneuver Classification")

uploaded_file = st.file_uploader("Upload a video file", type=["mp4", "avi", "mov"])

if uploaded_file is not None:
    video_bytes = uploaded_file.read()
    st.video(video_bytes)
    # read() leaves the upload positioned at its end; rewind it so the
    # decoder inside classify() starts from the first byte.
    uploaded_file.seek(0)
    predicted_label = classify(model_maneuver, model_Surf_notSurf, uploaded_file)
    st.success(f"Predicted Label: {predicted_label}")