2nzi commited on
Commit
f4281e2
1 Parent(s): d613e03
Files changed (2) hide show
  1. app.py +68 -36
  2. requirements.txt +0 -0
app.py CHANGED
@@ -3,7 +3,8 @@ import av
3
  import torch
4
  from transformers import AutoImageProcessor, AutoModelForVideoClassification
5
  import streamlit as st
6
-
 
7
 
8
  def read_video_pyav(container, indices):
9
  '''
@@ -43,61 +44,92 @@ def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
43
  indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
44
  return indices
45
 
46
- # def sample_frame_indices2(clip_len, frame_sample_rate, seg_len):
47
- # '''
48
- # Description
49
- # Args:
50
- # Returns:
51
- # indices (`List[int]`): List of sampled frame indices
52
- # '''
53
- # return
54
-
55
- def classify(file):
 
 
 
56
  container = av.open(file)
57
 
58
  # sample 16 frames
59
  indices = sample_frame_indices(clip_len=16, frame_sample_rate=4, seg_len=container.streams.video[0].frames)
60
  video = read_video_pyav(container, indices)
61
 
62
- if container.streams.video[0].frames < 16:
63
- return 'Video trop courte'
64
-
65
  inputs = image_processor(list(video), return_tensors="pt")
66
 
67
  with torch.no_grad():
68
- outputs = model(**inputs)
69
  logits = outputs.logits
70
 
71
- # model predicts one of the 400 Kinetics-400 classes
72
  predicted_label = logits.argmax(-1).item()
73
- print(model.config.id2label[predicted_label])
74
-
75
- return model.config.id2label[predicted_label]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
 
78
- model_ckpt = '2nzi/videomae-surf-analytics'
79
- # pipe = pipeline("video-classification", model="2nzi/videomae-surf-analytics")
80
- image_processor = AutoImageProcessor.from_pretrained(model_ckpt)
81
- model = AutoModelForVideoClassification.from_pretrained(model_ckpt)
82
 
83
 
84
 
 
 
 
 
 
 
 
 
 
85
 
 
86
 
87
- st.subheader("Surf Analytics")
 
 
 
 
 
88
 
89
- st.markdown("""
90
- Bienvenue sur le projet Surf Analytics réalisé par Walid, Guillaume, Valentine, et Antoine.
91
-
92
- <a href="https://github.com/2nzi/M09-FinalProject-Surf-Analytics" style="text-decoration: none;">@Surf-Analytics-Github</a>.
93
- """, unsafe_allow_html=True)
94
 
95
- st.title("Surf Maneuver Classification")
 
 
96
 
97
- uploaded_file = st.file_uploader("Upload a video file", type=["mp4", "avi", "mov"])
 
 
98
 
99
- if uploaded_file is not None:
100
- video_bytes = uploaded_file.read()
101
- st.video(video_bytes)
102
- predicted_label = classify(uploaded_file)
103
- st.success(f"Predicted Label: {predicted_label}")
 
3
  import torch
4
  from transformers import AutoImageProcessor, AutoModelForVideoClassification
5
  import streamlit as st
6
+ import torch.nn as nn
7
+ from streamlit_navigation_bar import st_navbar
8
 
9
  def read_video_pyav(container, indices):
10
  '''
 
44
  indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
45
  return indices
46
 
47
+ def victoire():
48
+ gif_url = "https://i.postimg.cc/rDp7xRJY/Happy-Birthday-Confetti.gif"
49
+ html_gif = f"""
50
+ <div style="display: flex; justify-content: center; align-items: center;">
51
+ <img src="{gif_url}" height="auto" style="margin: 0px;">
52
+ <img src="{gif_url}" height="auto" style="margin: 0px;">
53
+ <img src="{gif_url}" height="auto" style="margin: 0px;">
54
+ <img src="{gif_url}" height="auto" style="margin: 0px;">
55
+ </div>
56
+ """
57
+ st.markdown(html_gif, unsafe_allow_html=True)
58
+
59
+ def classify(model_maneuver,model_Surf_notSurf,file):
60
  container = av.open(file)
61
 
62
  # sample 16 frames
63
  indices = sample_frame_indices(clip_len=16, frame_sample_rate=4, seg_len=container.streams.video[0].frames)
64
  video = read_video_pyav(container, indices)
65
 
 
 
 
66
  inputs = image_processor(list(video), return_tensors="pt")
67
 
68
  with torch.no_grad():
69
+ outputs = model_Surf_notSurf(**inputs)
70
  logits = outputs.logits
71
 
 
72
  predicted_label = logits.argmax(-1).item()
73
+ print(model_Surf_notSurf.config.id2label[predicted_label])
74
+
75
+ if model_Surf_notSurf.config.id2label[predicted_label]!='Surfing':
76
+ return model_Surf_notSurf.config.id2label[predicted_label]
77
+ else:
78
+ with torch.no_grad():
79
+ outputs = model_maneuver(**inputs)
80
+ logits = outputs.logits
81
+
82
+ predicted_label = logits.argmax(-1).item()
83
+ print(model_maneuver.config.id2label[predicted_label])
84
+ # st.write(f'Les labels: {model_maneuver.config.id2label}')
85
+ # st.write(f'répartiton des probilités {logits}')
86
+ # st.write(f'répartiton des probilités {nn.Softmax(dim=-1)(logits)}')
87
+
88
+ return model_maneuver.config.id2label[predicted_label]
89
+
90
+
91
+ model_maneuver = '2nzi/videomae-surf-analytics'
92
+ model_Surf_notSurf = '2nzi/videomae-surf-analytics-surfNOTsurf'
93
+ image_processor = AutoImageProcessor.from_pretrained(model_maneuver)
94
+ model_maneuver = AutoModelForVideoClassification.from_pretrained(model_maneuver)
95
+ model_Surf_notSurf = AutoModelForVideoClassification.from_pretrained(model_Surf_notSurf)
96
+
97
+
98
 
99
 
100
+ # Define the navigation bar and its pages
101
+ page = st_navbar(["Home", "Documentation", "Examples", "About Us"])
 
 
102
 
103
 
104
 
105
+ # Main application code
106
+ if page == "Home":
107
+ st.subheader("Surf Analytics")
108
+ st.markdown("""
109
+ Bienvenue sur le projet Surf Analytics réalisé par Walid, Guillaume, Valentine, et Antoine.
110
+
111
+ <a href="https://github.com/2nzi/M09-FinalProject-Surf-Analytics" style="text-decoration: none;">@Surf-Analytics-Github</a>.
112
+ """, unsafe_allow_html=True)
113
+ st.title("Surf Maneuver Classification")
114
 
115
+ uploaded_file = st.file_uploader("Upload a video file", type=["mp4"])
116
 
117
+ if uploaded_file is not None:
118
+ video_bytes = uploaded_file.read()
119
+ st.video(video_bytes)
120
+ predicted_label = classify(model_maneuver, model_Surf_notSurf, uploaded_file)
121
+ st.success(f"Predicted Label: {predicted_label}")
122
+ victoire()
123
 
 
 
 
 
 
124
 
125
+ elif page == "Documentation":
126
+ st.title("Documentation")
127
+ st.markdown("Here you can add your documentation content.")
128
 
129
+ elif page == "Examples":
130
+ st.title("Examples")
131
+ st.markdown("Here you can add examples related to your project.")
132
 
133
+ elif page == "About Us":
134
+ st.title("About")
135
+ st.markdown("Here you can add information about the project and the team.")
 
 
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ