piperod91 committed on
Commit
690e199
·
1 Parent(s): 169c8af

fixing border and adding logo

Browse files
Files changed (3) hide show
  1. app.py +66 -19
  2. inference.py +1 -1
  3. metrics.py +2 -0
app.py CHANGED
@@ -33,7 +33,7 @@ import pathlib
33
  import multiprocessing as mp
34
  from time import time
35
 
36
- if not os.path.exists('videos_example'):
37
  REPO_ID='SharkSpace/videos_examples'
38
  snapshot_download(repo_id=REPO_ID, token=os.environ.get('SHARK_MODEL'),repo_type='dataset',local_dir='videos_example')
39
 
@@ -65,6 +65,46 @@ def overlay_text_on_image(image, text_list, font=cv2.FONT_HERSHEY_SIMPLEX, font_
65
  cv2.putText(image, line, (image.shape[1] - text_width - margin, y), font, font_size, color, font_thickness, lineType=cv2.LINE_AA)
66
  return image
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  def draw_cockpit(frame, top_pred,cnt):
69
  # Bullet points:
70
  high_danger_color = (255,0,0)
@@ -80,17 +120,21 @@ def draw_cockpit(frame, top_pred,cnt):
80
  strings = [shark_sighted, human_sighted, shark_size_estimate, shark_weight_estimate, danger_level]
81
  relative = max(frame.shape[0],frame.shape[1])
82
  if top_pred['shark_sighted'] and top_pred['dangerous_dist'] and cnt%2 == 0:
83
- relative = max(frame.shape[0],frame.shape[1])
84
  frame = add_border(frame, color=high_danger_color, thickness=int(relative*0.025))
 
85
  elif top_pred['shark_sighted'] and not top_pred['dangerous_dist'] and cnt%2 == 0:
86
- relative = max(frame.shape[0],frame.shape[1])
87
  frame = add_border(frame, color=low_danger_color, thickness=int(relative*0.025))
 
 
 
 
 
88
  overlay_text_on_image(frame, strings, font=cv2.FONT_HERSHEY_SIMPLEX, font_size=relative*0.0007, font_thickness=1, margin=int(relative*0.05), color=(255, 255, 255))
89
  return frame
90
 
91
 
92
 
93
- def process_video(input_video, out_fps = 'auto', skip_frames = 7):
94
  cap = cv2.VideoCapture(input_video)
95
 
96
  output_path = "output.mp4"
@@ -110,6 +154,8 @@ def process_video(input_video, out_fps = 'auto', skip_frames = 7):
110
  cnt = 0
111
 
112
  while iterating:
 
 
113
  if (cnt % skip_frames) == 0:
114
  print('starting Frame: ', cnt)
115
  # flip frame vertically
@@ -124,38 +170,39 @@ def process_video(input_video, out_fps = 'auto', skip_frames = 7):
124
 
125
 
126
  #frame = cv2.resize(frame, (int(width), int(height)))
 
127
 
128
  if cnt*skip_frames %2==0 and top_pred['shark_sighted']:
129
- #prediction_frame = cv2.resize(prediction_frame, (int(width), int(height)))
130
  frame =prediction_frame
131
-
132
  if top_pred['shark_sighted']:
133
  frame = draw_cockpit(frame, top_pred,cnt*skip_frames)
134
- video.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
 
 
135
  pred_dashbord = prediction_dashboard(top_pred = top_pred)
136
  #print('sending frame')
137
  print('finalizing frame:',cnt)
138
  print(pred_dashbord.shape)
139
  print(frame.shape)
140
  print(prediction_frame.shape)
141
- yield prediction_frame,frame , None, pred_dashbord
142
- print('overall count ', cnt)
143
  cnt += 1
144
  iterating, frame = cap.read()
145
 
146
  video.release()
147
- yield None, None, output_path, None
148
 
149
  with gr.Blocks(theme=theme) as demo:
150
- with gr.Row().style(equal_height=True,height='25%'):
151
  input_video = gr.Video(label="Input")
152
- processed_frames = gr.Image(label="Shark Engine")
153
- output_video = gr.Video(label="Output Video")
154
- dashboard = gr.Image(label="Dashboard")
155
-
156
- with gr.Row():
157
 
158
- original_frames = gr.Image(label="Original Frame").style( height=650)
 
 
 
159
 
160
  with gr.Row():
161
  paths = sorted(pathlib.Path('videos_example/').rglob('*.mp4'))
@@ -163,8 +210,8 @@ with gr.Blocks(theme=theme) as demo:
163
  examples = gr.Examples(samples, inputs=input_video)
164
  process_video_btn = gr.Button("Process Video")
165
 
166
- process_video_btn.click(process_video, input_video, [processed_frames, original_frames, output_video, dashboard])
167
-
168
  demo.queue()
169
  if os.getenv('SYSTEM') == 'spaces':
170
  demo.launch(width='40%',auth=(os.environ.get('SHARK_USERNAME'), os.environ.get('SHARK_PASSWORD')))
 
33
  import multiprocessing as mp
34
  from time import time
35
 
36
+ if not os.path.exists('videos_example') and not os.getenv('SYSTEM') == 'spaces':
37
  REPO_ID='SharkSpace/videos_examples'
38
  snapshot_download(repo_id=REPO_ID, token=os.environ.get('SHARK_MODEL'),repo_type='dataset',local_dir='videos_example')
39
 
 
65
  cv2.putText(image, line, (image.shape[1] - text_width - margin, y), font, font_size, color, font_thickness, lineType=cv2.LINE_AA)
66
  return image
67
 
68
+ def overlay_logo(frame,logo, position=(10, 10)):
69
+ """
70
+ Overlay a transparent logo (with alpha channel) on a frame.
71
+
72
+ Parameters:
73
+ - frame: The main image/frame to overlay the logo on.
74
+ - logo_path: Path to the logo image.
75
+ - position: (x, y) tuple indicating where the logo starts (top left corner).
76
+ """
77
+ # Load the logo and its alpha channel
78
+ alpha_channel = np.ones(logo.shape[:2], dtype=logo.dtype)
79
+ print(logo.min(),logo.max())
80
+ logo = np.dstack((logo, alpha_channel))
81
+
82
+ indexes = logo[:,:,1]>150
83
+ logo[indexes,3] = 0
84
+ l_channels = cv2.split(logo)
85
+ if len(l_channels) != 4:
86
+ raise ValueError("Logo doesn't have an alpha channel!")
87
+ l_b, l_g, l_r, l_alpha = l_channels
88
+ cv2.imwrite('l_alpha.png',l_alpha*255)
89
+ # Extract regions of interest (ROI) from both images
90
+ roi = frame[position[1]:position[1]+logo.shape[0], position[0]:position[0]+logo.shape[1]]
91
+
92
+ # Blend the logo using the alpha channel
93
+ for channel in range(0, 3):
94
+ roi[:, :, channel] = (l_alpha ) * l_channels[channel] + (1.0 - l_alpha ) * roi[:, :, channel]
95
+
96
+ return frame
97
+
98
+
99
+ def add_danger_symbol_from_image(frame, top_pred):
100
+ relative = max(frame.shape[0],frame.shape[1])
101
+ if top_pred['shark_sighted'] and top_pred['dangerous_dist']:
102
+ # Add the danger symbol
103
+ danger_symbol = cv2.imread('static/danger_symbol.jpeg')
104
+ danger_symbol = cv2.resize(danger_symbol, (int(relative*0.1), int(relative*0.1)), interpolation = cv2.INTER_AREA)[:,:,::-1]
105
+ frame = overlay_logo(frame,danger_symbol, position=(int(relative*0.05), int(relative*0.05)))
106
+ return frame
107
+
108
  def draw_cockpit(frame, top_pred,cnt):
109
  # Bullet points:
110
  high_danger_color = (255,0,0)
 
120
  strings = [shark_sighted, human_sighted, shark_size_estimate, shark_weight_estimate, danger_level]
121
  relative = max(frame.shape[0],frame.shape[1])
122
  if top_pred['shark_sighted'] and top_pred['dangerous_dist'] and cnt%2 == 0:
 
123
  frame = add_border(frame, color=high_danger_color, thickness=int(relative*0.025))
124
+ frame = add_danger_symbol_from_image(frame, top_pred)
125
  elif top_pred['shark_sighted'] and not top_pred['dangerous_dist'] and cnt%2 == 0:
 
126
  frame = add_border(frame, color=low_danger_color, thickness=int(relative*0.025))
127
+ frame = add_danger_symbol_from_image(frame, top_pred)
128
+ else:
129
+
130
+ frame = add_border(frame, color=(0,0,0), thickness=int(relative*0.025))
131
+
132
  overlay_text_on_image(frame, strings, font=cv2.FONT_HERSHEY_SIMPLEX, font_size=relative*0.0007, font_thickness=1, margin=int(relative*0.05), color=(255, 255, 255))
133
  return frame
134
 
135
 
136
 
137
+ def process_video(input_video,out_fps = 'auto', skip_frames = 7):
138
  cap = cv2.VideoCapture(input_video)
139
 
140
  output_path = "output.mp4"
 
154
  cnt = 0
155
 
156
  while iterating:
157
+ print('overall count ', cnt)
158
+
159
  if (cnt % skip_frames) == 0:
160
  print('starting Frame: ', cnt)
161
  # flip frame vertically
 
170
 
171
 
172
  #frame = cv2.resize(frame, (int(width), int(height)))
173
+ video.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
174
 
175
  if cnt*skip_frames %2==0 and top_pred['shark_sighted']:
176
+ prediction_frame = cv2.resize(prediction_frame, (int(width), int(height)))
177
  frame =prediction_frame
178
+
179
  if top_pred['shark_sighted']:
180
  frame = draw_cockpit(frame, top_pred,cnt*skip_frames)
181
+ video.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
182
+
183
+
184
  pred_dashbord = prediction_dashboard(top_pred = top_pred)
185
  #print('sending frame')
186
  print('finalizing frame:',cnt)
187
  print(pred_dashbord.shape)
188
  print(frame.shape)
189
  print(prediction_frame.shape)
190
+ yield frame , None
191
+
192
  cnt += 1
193
  iterating, frame = cap.read()
194
 
195
  video.release()
196
+ yield None, output_path
197
 
198
  with gr.Blocks(theme=theme) as demo:
199
+ with gr.Row().style(equal_height=True):
200
  input_video = gr.Video(label="Input")
 
 
 
 
 
201
 
202
+ original_frames = gr.Image(label="Processed Frame").style( height=650)
203
+ #processed_frames = gr.Image(label="Shark Engine")
204
+ output_video = gr.Video(label="Output Video")
205
+ #dashboard = gr.Image(label="Events")
206
 
207
  with gr.Row():
208
  paths = sorted(pathlib.Path('videos_example/').rglob('*.mp4'))
 
210
  examples = gr.Examples(samples, inputs=input_video)
211
  process_video_btn = gr.Button("Process Video")
212
 
213
+ #process_video_btn.click(process_video, input_video, [processed_frames, original_frames, output_video, dashboard])
214
+ process_video_btn.click(process_video, input_video, [ original_frames, output_video])
215
  demo.queue()
216
  if os.getenv('SYSTEM') == 'spaces':
217
  demo.launch(width='40%',auth=(os.environ.get('SHARK_USERNAME'), os.environ.get('SHARK_PASSWORD')))
inference.py CHANGED
@@ -133,7 +133,7 @@ classes_is_human_id = [i for i, x in enumerate(classes_is_human) if x == 1]
133
  classes_is_unknown_id = [i for i, x in enumerate(classes_is_unknown) if x == 1]
134
 
135
 
136
- if not os.path.exists('model'):
137
  REPO_ID = "SharkSpace/maskformer_model"
138
  FILENAME = "mask2former"
139
  snapshot_download(repo_id=REPO_ID, token= os.environ.get('SHARK_MODEL'),local_dir='model/')
 
133
  classes_is_unknown_id = [i for i, x in enumerate(classes_is_unknown) if x == 1]
134
 
135
 
136
+ if not os.path.exists('model') and not os.getenv('SYSTEM') == 'spaces':
137
  REPO_ID = "SharkSpace/maskformer_model"
138
  FILENAME = "mask2former"
139
  snapshot_download(repo_id=REPO_ID, token= os.environ.get('SHARK_MODEL'),local_dir='model/')
metrics.py CHANGED
@@ -119,6 +119,8 @@ def get_min_distance_shark_person(top_pred, class_sizes = None, dangerous_distan
119
  'dangerous_dist': min_dist < dangerous_distance}
120
 
121
  def _calculate_dist_estimate(bbox1, bbox2, labels, class_sizes = None, measurement = 'feet'):
 
 
122
  class_feet_size_mean = np.array([class_sizes[labels[0]][measurement][0],
123
  class_sizes[labels[1]][measurement][0]]).mean()
124
  box_pixel_size_mean = np.array([np.linalg.norm(bbox1[[0, 1]] - bbox1[[2, 3]]),
 
119
  'dangerous_dist': min_dist < dangerous_distance}
120
 
121
  def _calculate_dist_estimate(bbox1, bbox2, labels, class_sizes = None, measurement = 'feet'):
122
+ if class_sizes[labels[0]] == None or class_sizes[labels[1]] == None:
123
+ return 9999
124
  class_feet_size_mean = np.array([class_sizes[labels[0]][measurement][0],
125
  class_sizes[labels[1]][measurement][0]]).mean()
126
  box_pixel_size_mean = np.array([np.linalg.norm(bbox1[[0, 1]] - bbox1[[2, 3]]),