Iskaj commited on
Commit
8d6b883
1 Parent(s): a7b0fd6

added autocompare to auto determine good min_distance val

Browse files
Files changed (1) hide show
  1. app.py +36 -13
app.py CHANGED
@@ -16,13 +16,14 @@ import matplotlib.pyplot as plt
16
  import imagehash
17
  from PIL import Image
18
 
19
- import numpy as np
20
  import pandas as pd
21
  import faiss
22
 
23
  import shutil
24
 
25
  FPS = 5
 
26
  MAX_DISTANCE = 30
27
 
28
  video_directory = tempfile.gettempdir()
@@ -104,13 +105,6 @@ def index_hashes_for_video(url, is_file = False):
104
  logging.info(f"Indexed hashes for {index.ntotal} frames to {filename}.index.")
105
  return index
106
 
107
- def get_comparison(url, target, MIN_DISTANCE = 3):
108
- """ Function for Gradio to combine all helper functions"""
109
- video_index, hash_vectors, target_indices = get_video_indices(url, target, MIN_DISTANCE = MIN_DISTANCE)
110
- lims, D, I, hash_vectors = compare_videos(video_index, hash_vectors, target_indices, MIN_DISTANCE = MIN_DISTANCE)
111
- fig = plot_comparison(lims, D, I, hash_vectors, MIN_DISTANCE = MIN_DISTANCE)
112
- return fig
113
-
114
  def get_video_indices(url, target, MIN_DISTANCE = 4):
115
  """" The comparison between the target and the original video will be plotted based
116
  on the matches between the target and the original video over time. The matches are determined
@@ -144,8 +138,18 @@ def compare_videos(video_index, hash_vectors, target_indices, MIN_DISTANCE = 3):
144
  lims, D, I = target_indices[0].range_search(hash_vectors, MIN_DISTANCE)
145
  return lims, D, I, hash_vectors
146
 
147
- def plot_distances(target_indices, hash_vectors, MIN_DISTANCE, MAX_DISTANCE):
148
- pass
 
 
 
 
 
 
 
 
 
 
149
 
150
  def plot_comparison(lims, D, I, hash_vectors, MIN_DISTANCE = 3):
151
  sns.set_theme()
@@ -179,7 +183,22 @@ def plot_comparison(lims, D, I, hash_vectors, MIN_DISTANCE = 3):
179
  return fig
180
 
181
  logging.basicConfig()
182
- logging.getLogger().setLevel(logging.DEBUG)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
  video_urls = ["https://www.dropbox.com/s/8c89a9aba0w8gjg/Ploumen.mp4?dl=1",
185
  "https://www.dropbox.com/s/rzmicviu1fe740t/Bram%20van%20Ojik%20krijgt%20reprimande.mp4?dl=1",
@@ -194,14 +213,18 @@ compare_iface = gr.Interface(fn=get_comparison,
194
  inputs=["text", "text", gr.Slider(2, 30, 4, step=2)], outputs="plot",
195
  examples=[[x, video_urls[-1]] for x in video_urls[:-1]])
196
 
197
- iface = gr.TabbedInterface([index_iface, compare_iface], ["Index", "Compare"])
 
 
 
 
198
 
199
  if __name__ == "__main__":
200
  import matplotlib
201
  matplotlib.use('SVG') # To be able to plot in gradio
202
 
203
  logging.basicConfig()
204
- logging.getLogger().setLevel(logging.DEBUG)
205
 
206
  iface.launch()
207
  #iface.launch(auth=("test", "test"), share=True, debug=True)
 
16
  import imagehash
17
  from PIL import Image
18
 
19
+ import numpy as np
20
  import pandas as pd
21
  import faiss
22
 
23
  import shutil
24
 
25
  FPS = 5
26
+ MIN_DISTANCE = 4
27
  MAX_DISTANCE = 30
28
 
29
  video_directory = tempfile.gettempdir()
 
105
  logging.info(f"Indexed hashes for {index.ntotal} frames to {filename}.index.")
106
  return index
107
 
 
 
 
 
 
 
 
108
  def get_video_indices(url, target, MIN_DISTANCE = 4):
109
  """" The comparison between the target and the original video will be plotted based
110
  on the matches between the target and the original video over time. The matches are determined
 
138
  lims, D, I = target_indices[0].range_search(hash_vectors, MIN_DISTANCE)
139
  return lims, D, I, hash_vectors
140
 
141
+ def get_decent_distance(url, target, MIN_DISTANCE, MAX_DISTANCE):
142
+ """ To get a decent heurstic for a base distance check every distance from MIN_DISTANCE to MAX_DISTANCE
143
+ until the number of matches found is equal to or higher than the number of frames in the source video"""
144
+ for distance in np.arange(start = MIN_DISTANCE - 2, stop = MAX_DISTANCE + 2, step = 2, dtype=int):
145
+ distance = int(distance)
146
+ video_index, hash_vectors, target_indices = get_video_indices(url, target, MIN_DISTANCE = distance)
147
+ lims, D, I, hash_vectors = compare_videos(video_index, hash_vectors, target_indices, MIN_DISTANCE = distance)
148
+ nr_source_frames = video_index.ntotal
149
+ nr_matches = len(D)
150
+ logging.info(f"{(nr_matches/nr_source_frames) * 100.0:.1f}% of frames have a match for distance '{distance}' ({nr_matches} matches for {nr_source_frames} frames)")
151
+ if nr_matches >= nr_source_frames:
152
+ return distance
153
 
154
  def plot_comparison(lims, D, I, hash_vectors, MIN_DISTANCE = 3):
155
  sns.set_theme()
 
183
  return fig
184
 
185
  logging.basicConfig()
186
+ logging.getLogger().setLevel(logging.INFO)
187
+
188
+ def get_comparison(url, target, MIN_DISTANCE = 4):
189
+ """ Function for Gradio to combine all helper functions"""
190
+ video_index, hash_vectors, target_indices = get_video_indices(url, target, MIN_DISTANCE = MIN_DISTANCE)
191
+ lims, D, I, hash_vectors = compare_videos(video_index, hash_vectors, target_indices, MIN_DISTANCE = MIN_DISTANCE)
192
+ fig = plot_comparison(lims, D, I, hash_vectors, MIN_DISTANCE = MIN_DISTANCE)
193
+ return fig
194
+
195
+ def get_auto_comparison(url, target, MIN_DISTANCE = MIN_DISTANCE):
196
+ """ Function for Gradio to combine all helper functions"""
197
+ distance = get_decent_distance(url, target, MIN_DISTANCE, MAX_DISTANCE)
198
+ video_index, hash_vectors, target_indices = get_video_indices(url, target, MIN_DISTANCE = distance)
199
+ lims, D, I, hash_vectors = compare_videos(video_index, hash_vectors, target_indices, MIN_DISTANCE = distance)
200
+ fig = plot_comparison(lims, D, I, hash_vectors, MIN_DISTANCE = distance)
201
+ return fig
202
 
203
  video_urls = ["https://www.dropbox.com/s/8c89a9aba0w8gjg/Ploumen.mp4?dl=1",
204
  "https://www.dropbox.com/s/rzmicviu1fe740t/Bram%20van%20Ojik%20krijgt%20reprimande.mp4?dl=1",
 
213
  inputs=["text", "text", gr.Slider(2, 30, 4, step=2)], outputs="plot",
214
  examples=[[x, video_urls[-1]] for x in video_urls[:-1]])
215
 
216
+ auto_compare_iface = gr.Interface(fn=get_auto_comparison,
217
+ inputs=["text", "text"], outputs="plot",
218
+ examples=[[x, video_urls[-1]] for x in video_urls[:-1]])
219
+
220
+ iface = gr.TabbedInterface([index_iface, compare_iface, auto_compare_iface], ["Index", "Compare", "AutoCompare"])
221
 
222
  if __name__ == "__main__":
223
  import matplotlib
224
  matplotlib.use('SVG') # To be able to plot in gradio
225
 
226
  logging.basicConfig()
227
+ logging.getLogger().setLevel(logging.INFO)
228
 
229
  iface.launch()
230
  #iface.launch(auth=("test", "test"), share=True, debug=True)