2nzi commited on
Commit
d613e03
1 Parent(s): 64a2e80

upload files

Browse files
Files changed (2) hide show
  1. app.py +103 -0
  2. requirements.txt +0 -0
app.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import av
3
+ import torch
4
+ from transformers import AutoImageProcessor, AutoModelForVideoClassification
5
+ import streamlit as st
6
+
7
+
8
+ def read_video_pyav(container, indices):
9
+ '''
10
+ Decode the video with PyAV decoder.
11
+ Args:
12
+ container (`av.container.input.InputContainer`): PyAV container.
13
+ indices (`List[int]`): List of frame indices to decode.
14
+ Returns:
15
+ result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
16
+ '''
17
+ frames = []
18
+ container.seek(0)
19
+ start_index = indices[0]
20
+ end_index = indices[-1]
21
+ for i, frame in enumerate(container.decode(video=0)):
22
+ if i > end_index:
23
+ break
24
+ if i >= start_index and i in indices:
25
+ frames.append(frame)
26
+ return np.stack([x.to_ndarray(format="rgb24") for x in frames])
27
+
28
+
29
+ def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
30
+ '''
31
+ Sample a given number of frame indices from the video.
32
+ Args:
33
+ clip_len (`int`): Total number of frames to sample.
34
+ frame_sample_rate (`int`): Sample every n-th frame.
35
+ seg_len (`int`): Maximum allowed index of sample's last frame.
36
+ Returns:
37
+ indices (`List[int]`): List of sampled frame indices
38
+ '''
39
+ converted_len = int(clip_len * frame_sample_rate)
40
+ end_idx = np.random.randint(converted_len, seg_len)
41
+ start_idx = end_idx - converted_len
42
+ indices = np.linspace(start_idx, end_idx, num=clip_len)
43
+ indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
44
+ return indices
45
+
46
+ # def sample_frame_indices2(clip_len, frame_sample_rate, seg_len):
47
+ # '''
48
+ # Description
49
+ # Args:
50
+ # Returns:
51
+ # indices (`List[int]`): List of sampled frame indices
52
+ # '''
53
+ # return
54
+
55
+ def classify(file):
56
+ container = av.open(file)
57
+
58
+ # sample 16 frames
59
+ indices = sample_frame_indices(clip_len=16, frame_sample_rate=4, seg_len=container.streams.video[0].frames)
60
+ video = read_video_pyav(container, indices)
61
+
62
+ if container.streams.video[0].frames < 16:
63
+ return 'Video trop courte'
64
+
65
+ inputs = image_processor(list(video), return_tensors="pt")
66
+
67
+ with torch.no_grad():
68
+ outputs = model(**inputs)
69
+ logits = outputs.logits
70
+
71
+ # model predicts one of the 400 Kinetics-400 classes
72
+ predicted_label = logits.argmax(-1).item()
73
+ print(model.config.id2label[predicted_label])
74
+
75
+ return model.config.id2label[predicted_label]
76
+
77
+
78
+ model_ckpt = '2nzi/videomae-surf-analytics'
79
+ # pipe = pipeline("video-classification", model="2nzi/videomae-surf-analytics")
80
+ image_processor = AutoImageProcessor.from_pretrained(model_ckpt)
81
+ model = AutoModelForVideoClassification.from_pretrained(model_ckpt)
82
+
83
+
84
+
85
+
86
+
87
+ st.subheader("Surf Analytics")
88
+
89
+ st.markdown("""
90
+ Bienvenue sur le projet Surf Analytics réalisé par Walid, Guillaume, Valentine, et Antoine.
91
+
92
+ <a href="https://github.com/2nzi/M09-FinalProject-Surf-Analytics" style="text-decoration: none;">@Surf-Analytics-Github</a>.
93
+ """, unsafe_allow_html=True)
94
+
95
+ st.title("Surf Maneuver Classification")
96
+
97
+ uploaded_file = st.file_uploader("Upload a video file", type=["mp4", "avi", "mov"])
98
+
99
+ if uploaded_file is not None:
100
+ video_bytes = uploaded_file.read()
101
+ st.video(video_bytes)
102
+ predicted_label = classify(uploaded_file)
103
+ st.success(f"Predicted Label: {predicted_label}")
requirements.txt ADDED
Binary file (118 Bytes). View file