Iskaj committed on
Commit
3b3290d
1 Parent(s): 7f2c8f8

added comments, added data aggregation for decision making

Browse files
Files changed (2) hide show
  1. Matching Exploration.ipynb +0 -0
  2. app.py +51 -17
Matching Exploration.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
app.py CHANGED
@@ -4,15 +4,20 @@ import logging
4
  import os
5
  import hashlib
6
  import datetime
 
7
 
8
  import pandas
9
  import gradio as gr
10
  from moviepy.editor import VideoFileClip
11
 
 
 
 
12
  import imagehash
13
  from PIL import Image
14
 
15
  import numpy as np
 
16
  import faiss
17
 
18
  FPS = 5
@@ -26,6 +31,8 @@ def download_video_from_url(url):
26
  with (urllib.request.urlopen(url)) as f, open(filename, 'wb') as fileout:
27
  fileout.write(f.read())
28
  logging.info(f"Downloaded video from {url} to {filename}.")
 
 
29
  return filename
30
 
31
  def change_ffmpeg_fps(clip, fps=FPS):
@@ -51,13 +58,19 @@ def binary_array_to_uint8s(arr):
51
 
52
  def compute_hashes(clip, fps=FPS):
53
  for index, frame in enumerate(change_ffmpeg_fps(clip, fps).iter_frames()):
 
 
 
54
  hashed = np.array(binary_array_to_uint8s(compute_hash(frame).hash), dtype='uint8')
55
  yield {"frame": 1+index*fps, "hash": hashed}
56
 
57
  def index_hashes_for_video(url):
58
  filename = download_video_from_url(url)
59
  if os.path.exists(f'{filename}.index'):
60
- return faiss.read_index_binary(f'{filename}.index')
 
 
 
61
 
62
  hash_vectors = np.array([x['hash'] for x in compute_hashes(VideoFileClip(filename))])
63
  logging.info(f"Computed hashes for {hash_vectors.shape} frames.")
@@ -87,33 +100,54 @@ def compare_videos(url, target, MIN_DISTANCE = 3):
87
  """
88
  # TODO: Fix crash if no matches are found
89
 
 
90
  video_index = index_hashes_for_video(url)
91
- target_indices = [index_hashes_for_video(x) for x in [target]]
92
-
93
  video_index.make_direct_map() # Make sure the index is indexable
94
  hash_vectors = np.array([video_index.reconstruct(i) for i in range(video_index.ntotal)]) # Retrieve original indices
95
 
 
 
 
96
  # The results are returned as a triplet of 1D arrays
97
  # lims, D, I, where result for query i is in I[lims[i]:lims[i+1]]
98
  # (indices of neighbors), D[lims[i]:lims[i+1]] (distances).
99
-
100
  lims, D, I = target_indices[0].range_search(hash_vectors, MIN_DISTANCE)
101
 
102
-
 
 
 
103
 
104
  x = [(lims[i+1]-lims[i]) * [i] for i in range(hash_vectors.shape[0])]
105
- x = [datetime.datetime(1970, 1, 1, 0, 0) + datetime.timedelta(seconds=i/FPS) for j in x for i in j]
106
- y = [datetime.datetime(1970, 1, 1, 0, 0) + datetime.timedelta(seconds=i/FPS) for i in I]
107
-
108
- import matplotlib.pyplot as plt
109
-
110
- ax = plt.figure()
111
- if x and y:
112
- plt.scatter(x, y, s=2*(1-D/MIN_DISTANCE), alpha=1-D/MIN_DISTANCE)
113
- plt.xlabel('Time in source video (seconds)')
114
- plt.ylabel('Time in target video (seconds)')
115
- plt.show()
116
- return ax
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  video_urls = ["https://www.dropbox.com/s/8c89a9aba0w8gjg/Ploumen.mp4?dl=1",
119
  "https://www.dropbox.com/s/rzmicviu1fe740t/Bram%20van%20Ojik%20krijgt%20reprimande.mp4?dl=1",
 
4
  import os
5
  import hashlib
6
  import datetime
7
+ import time
8
 
9
  import pandas
10
  import gradio as gr
11
  from moviepy.editor import VideoFileClip
12
 
13
+ import seaborn as sns
14
+ import matplotlib.pyplot as plt
15
+
16
  import imagehash
17
  from PIL import Image
18
 
19
  import numpy as np
20
+ import pandas as pd
21
  import faiss
22
 
23
  FPS = 5
 
31
  with (urllib.request.urlopen(url)) as f, open(filename, 'wb') as fileout:
32
  fileout.write(f.read())
33
  logging.info(f"Downloaded video from {url} to {filename}.")
34
+ else:
35
+ logging.info(f"Skipping downloading from {url} because {filename} already exists.")
36
  return filename
37
 
38
  def change_ffmpeg_fps(clip, fps=FPS):
 
58
 
59
  def compute_hashes(clip, fps=FPS):
60
  for index, frame in enumerate(change_ffmpeg_fps(clip, fps).iter_frames()):
61
+ # Each frame is a triplet of size (height, width, 3) of the video since it is RGB
62
+ # The hash itself is of size (hash_size, hash_size)
63
+ # The uint8 version of the hash is of size (hash_size * highfreq_factor,) and represents the hash
64
  hashed = np.array(binary_array_to_uint8s(compute_hash(frame).hash), dtype='uint8')
65
  yield {"frame": 1+index*fps, "hash": hashed}
66
 
67
  def index_hashes_for_video(url):
68
  filename = download_video_from_url(url)
69
  if os.path.exists(f'{filename}.index'):
70
+ logging.info(f"Loading indexed hashes from {filename}.index")
71
+ binary_index = faiss.read_index_binary(f'{filename}.index')
72
+ logging.info(f"Index {filename}.index has in total {binary_index.ntotal} frames")
73
+ return binary_index
74
 
75
  hash_vectors = np.array([x['hash'] for x in compute_hashes(VideoFileClip(filename))])
76
  logging.info(f"Computed hashes for {hash_vectors.shape} frames.")
 
100
  """
101
  # TODO: Fix crash if no matches are found
102
 
103
+ # Url (short video)
104
  video_index = index_hashes_for_video(url)
 
 
105
  video_index.make_direct_map() # Make sure the index is indexable
106
  hash_vectors = np.array([video_index.reconstruct(i) for i in range(video_index.ntotal)]) # Retrieve original indices
107
 
108
+ # Target video (long video)
109
+ target_indices = [index_hashes_for_video(x) for x in [target]]
110
+
111
  # The results are returned as a triplet of 1D arrays
112
  # lims, D, I, where result for query i is in I[lims[i]:lims[i+1]]
113
  # (indices of neighbors), D[lims[i]:lims[i+1]] (distances).
 
114
  lims, D, I = target_indices[0].range_search(hash_vectors, MIN_DISTANCE)
115
 
116
+ return plot_comparison(lims, D, I, hash_vectors, MIN_DISTANCE = MIN_DISTANCE)
117
+
118
+ def plot_comparison(lims, D, I, hash_vectors, MIN_DISTANCE = 3):
119
+ sns.set_theme()
120
 
121
  x = [(lims[i+1]-lims[i]) * [i] for i in range(hash_vectors.shape[0])]
122
+ x = [i/FPS for j in x for i in j]
123
+ y = [i/FPS for i in I]
124
+
125
+ # Create figure and dataframe to plot with sns
126
+ fig = plt.figure()
127
+ # plt.tight_layout()
128
+ df = pd.DataFrame(zip(x, y), columns = ['X', 'Y'])
129
+ g = sns.scatterplot(data=df, x='X', y='Y', s=2*(1-D/(MIN_DISTANCE+1)), alpha=1-D/MIN_DISTANCE)
130
+
131
+ # Set x-labels to be more readable
132
+ x_locs, x_labels = plt.xticks() # Get original locations and labels for x ticks
133
+ x_labels = [time.strftime('%H:%M:%S', time.gmtime(x)) for x in x_locs]
134
+ plt.xticks(x_locs, x_labels)
135
+ plt.xticks(rotation=90)
136
+ plt.xlabel('Time in source video (H:M:S)')
137
+ plt.xlim(0, None)
138
+
139
+ # Set y-labels to be more readable
140
+ y_locs, y_labels = plt.yticks() # Get original locations and labels for y ticks
141
+ y_labels = [time.strftime('%H:%M:%S', time.gmtime(y)) for y in y_locs]
142
+ plt.yticks(y_locs, y_labels)
143
+ plt.ylabel('Time in target video (H:M:S)')
144
+
145
+ # Adjust padding to fit gradio
146
+ plt.subplots_adjust(bottom=0.25, left=0.20)
147
+ return fig
148
+
149
+ logging.basicConfig()
150
+ logging.getLogger().setLevel(logging.DEBUG)
151
 
152
  video_urls = ["https://www.dropbox.com/s/8c89a9aba0w8gjg/Ploumen.mp4?dl=1",
153
  "https://www.dropbox.com/s/rzmicviu1fe740t/Bram%20van%20Ojik%20krijgt%20reprimande.mp4?dl=1",