Iskaj committed on
Commit
1991773
1 Parent(s): 9807395

changed to work with apb files, plotting separated from decision

Files changed (5)
  1. app.py +74 -29
  2. clip_data.ipynb +3 -3
  3. config.py +2 -1
  4. plot.py +27 -18
  5. videomatch.py +7 -1
app.py CHANGED
@@ -1,16 +1,31 @@
 import logging
 
 import gradio as gr
 
 from config import *
 from videomatch import index_hashes_for_video, get_decent_distance, \
-    get_video_index, compare_videos, get_change_points, get_videomatch_df
 from plot import plot_comparison, plot_multi_comparison, plot_segment_comparison
 
 logging.basicConfig()
 logging.getLogger().setLevel(logging.INFO)
 
-
 def get_comparison(url, target, MIN_DISTANCE = 4):
     """ Function for Gradio to combine all helper functions"""
     video_index, hash_vectors = get_video_index(url)
@@ -19,25 +34,53 @@ def get_comparison(url, target, MIN_DISTANCE = 4):
     fig = plot_comparison(lims, D, I, hash_vectors, MIN_DISTANCE = MIN_DISTANCE)
     return fig
 
-def get_auto_comparison(url, target, smoothing_window_size=10, metric="OFFSET_LIP"):
-    """ Function for Gradio to combine all helper functions"""
     source_index, source_hash_vectors = get_video_index(url)
     target_index, _ = get_video_index(target)
     distance = get_decent_distance(source_index, source_hash_vectors, target_index, MIN_DISTANCE, MAX_DISTANCE)
     if distance == None:
-        return _, []
-        # raise gr.Error("No matches found!")
-
-    # For each video do...
-    for i in range(0, 1):
         lims, D, I, hash_vectors = compare_videos(source_hash_vectors, target_index, MIN_DISTANCE = distance)
         df = get_videomatch_df(lims, D, I, hash_vectors, distance)
-        change_points = get_change_points(df, smoothing_window_size=smoothing_window_size, metric=metric, method="ROBUST")
-        fig, segment_decision = plot_segment_comparison(df, change_points, video_id="Placeholder_Video_ID")
-        return fig, segment_decision
 
 
-video_urls = ["https://www.dropbox.com/s/8c89a9aba0w8gjg/Ploumen.mp4?dl=1",
              "https://www.dropbox.com/s/rzmicviu1fe740t/Bram%20van%20Ojik%20krijgt%20reprimande.mp4?dl=1",
              "https://www.dropbox.com/s/wcot34ldmb84071/Baudet%20ontmaskert%20Omtzigt_%20u%20bent%20door%20de%20mand%20gevallen%21.mp4?dl=1",
              "https://drive.google.com/uc?id=1XW0niHR1k09vPNv1cp6NvdGXe7FHJc1D&export=download",
@@ -46,22 +89,24 @@ video_urls = ["https://www.dropbox.com/s/8c89a9aba0w8gjg/Ploumen.mp4?dl=1",
 index_iface = gr.Interface(fn=lambda url: index_hashes_for_video(url).ntotal,
                            inputs="text",
                            outputs="text",
-                           examples=video_urls, cache_examples=True)
-
-compare_iface = gr.Interface(fn=get_comparison,
-                             inputs=["text", "text", gr.Slider(2, 30, 4, step=2)],
-                             outputs="plot",
-                             examples=[[x, video_urls[-1]] for x in video_urls[:-1]])
-
-auto_compare_iface = gr.Interface(fn=get_auto_comparison,
-                                  inputs=["text",
-                                          "text",
-                                          gr.Slider(2, 50, 10, step=1),
-                                          gr.Dropdown(choices=["OFFSET_LIP", "ROLL_OFFSET_MODE"], value="OFFSET_LIP")],
-                                  outputs=["plot", "json"],
-                                  examples=[[x, video_urls[-1]] for x in video_urls[:-1]])
-
-iface = gr.TabbedInterface([auto_compare_iface, compare_iface, index_iface,], ["AutoCompare", "Compare", "Index"])
 
 if __name__ == "__main__":
     import matplotlib
 
 import logging
+import os
+import json
+import matplotlib.pyplot as plt
 
 import gradio as gr
+from faiss import read_index_binary, write_index_binary
 
 from config import *
 from videomatch import index_hashes_for_video, get_decent_distance, \
+    get_video_index, compare_videos, get_change_points, get_videomatch_df, \
+    get_target_urls
 from plot import plot_comparison, plot_multi_comparison, plot_segment_comparison
 
 logging.basicConfig()
 logging.getLogger().setLevel(logging.INFO)
 
+def transfer_data_indices_to_temp(temp_path = VIDEO_DIRECTORY, data_path='./data'):
+    """ The binary indices built from the .json file are not stored in the temporary directory.
+    This function loads those indices and writes them to the temporary directory,
+    so that dynamically downloaded files and the static index files end up
+    linked in the same way. """
+    index_files = os.listdir(data_path)
+    for index_file in index_files:
+        # Read from static location and write to temp storage
+        binary_index = read_index_binary(os.path.join(data_path, index_file))
+        write_index_binary(binary_index, f'{temp_path}/{index_file}')
+
 def get_comparison(url, target, MIN_DISTANCE = 4):
     """ Function for Gradio to combine all helper functions"""
     video_index, hash_vectors = get_video_index(url)
 
     fig = plot_comparison(lims, D, I, hash_vectors, MIN_DISTANCE = MIN_DISTANCE)
     return fig
 
+def compare(url, target):
+    # Get source and target indices
     source_index, source_hash_vectors = get_video_index(url)
     target_index, _ = get_video_index(target)
+
+    # Get decent distance by comparing url index with the target hash vectors + target index
     distance = get_decent_distance(source_index, source_hash_vectors, target_index, MIN_DISTANCE, MAX_DISTANCE)
     if distance == None:
+        logging.info(f"No matches found between {url} and {target}!")
+        return plt.figure(), []
+    else:
+        # Compare videos with heuristic distance
         lims, D, I, hash_vectors = compare_videos(source_hash_vectors, target_index, MIN_DISTANCE = distance)
+
+        # Get dataframe holding all information
         df = get_videomatch_df(lims, D, I, hash_vectors, distance)
+
+        # Determine change point using ROBUST method based on column ROLL_OFFSET_MODE
+        change_points = get_change_points(df, metric="ROLL_OFFSET_MODE", method="ROBUST")
+
+        # Plot and get figure and .json-style segment decision
+        fig, segment_decision = plot_segment_comparison(df, change_points, video_id=target)
+        return fig, segment_decision
+
+def multiple_comparison(url, return_figure=False):
+    targets = get_target_urls()
+
+    # Figure and decision (list of dicts) storage
+    figures, decisions = [], []
+    for target in targets:
+        # Make comparison
+        fig, segment_decision = compare(url, target)
+
+        # Add decisions to global decision list
+        decisions.extend(segment_decision)
+        figures.append(fig)
 
+    if return_figure:
+        return figures
+    return decisions
 
+def plot_multiple_comparison(url):
+    return multiple_comparison(url, return_figure=True)
+
+transfer_data_indices_to_temp() # NOTE: Only works after doing 'git lfs pull' to actually obtain the .index files
+example_video_urls = ["https://drive.google.com/uc?id=1Y1-ypXOvLrp1x0cjAe_hMobCEdA0UbEo&export=download",
+                      "https://www.dropbox.com/s/8c89a9aba0w8gjg/Ploumen.mp4?dl=1",
                       "https://www.dropbox.com/s/rzmicviu1fe740t/Bram%20van%20Ojik%20krijgt%20reprimande.mp4?dl=1",
                       "https://www.dropbox.com/s/wcot34ldmb84071/Baudet%20ontmaskert%20Omtzigt_%20u%20bent%20door%20de%20mand%20gevallen%21.mp4?dl=1",
                       "https://drive.google.com/uc?id=1XW0niHR1k09vPNv1cp6NvdGXe7FHJc1D&export=download",
 
 index_iface = gr.Interface(fn=lambda url: index_hashes_for_video(url).ntotal,
                            inputs="text",
                            outputs="text",
+                           examples=example_video_urls, cache_examples=True)
+
+# compare_iface = gr.Interface(fn=get_comparison,
+#                              inputs=["text", "text", gr.Slider(2, 30, 4, step=2)],
+#                              outputs="plot",
+#                              examples=[[x, example_video_urls[-1]] for x in example_video_urls[:-1]])
+
+plot_compare_iface = gr.Interface(fn=plot_multiple_comparison,
+                                  inputs=["text"],
+                                  outputs=[gr.Plot() for _ in range(len(get_target_urls()))],
+                                  examples=example_video_urls)
+
+auto_compare_iface = gr.Interface(fn=multiple_comparison,
+                                  inputs=["text"],
+                                  outputs=["json"],
+                                  examples=example_video_urls)
+
+iface = gr.TabbedInterface([auto_compare_iface, plot_compare_iface, index_iface], ["AutoCompare", "PlotAutoCompare", "Index"])
 
 if __name__ == "__main__":
     import matplotlib
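
A minimal usage sketch of the new flow (not part of the commit): the source URL below is a placeholder, and importing app.py as a module is an assumption (doing so also builds the Gradio interfaces at import time).

    # Sketch only: exercising the new comparison flow outside Gradio.
    from app import multiple_comparison, plot_multiple_comparison  # assumes app.py is importable

    source_url = "https://example.org/uploads/user_clip.mp4"  # placeholder source video

    # multiple_comparison() loops over every target from get_target_urls(),
    # runs compare(source_url, target) for each, and concatenates the per-target
    # segment decisions into one JSON-style list of dicts.
    decisions = multiple_comparison(source_url)

    # The plotting variant returns one matplotlib figure per target instead.
    figures = plot_multiple_comparison(source_url)
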
clip_data.ipynb CHANGED
@@ -395,7 +395,7 @@
   ],
   "metadata": {
    "kernelspec": {
-    "display_name": "Python 3.9.7 ('base')",
     "language": "python",
     "name": "python3"
    },
@@ -409,12 +409,12 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-    "version": "3.9.7"
   },
   "orig_nbformat": 4,
   "vscode": {
    "interpreter": {
-    "hash": "35ac539f20c4edc7c4b10c8a5969be22a35cbd7bf12b66c83932160e8a573333"
    }
   }
  },
 
   ],
   "metadata": {
    "kernelspec": {
+    "display_name": "Python 3.9.13 64-bit",
     "language": "python",
     "name": "python3"
    },
 
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
+    "version": "3.9.13"
   },
   "orig_nbformat": 4,
   "vscode": {
    "interpreter": {
+    "hash": "397704579725e15f5c7cb49fe5f0341eb7531c82d19f2c29d197e8b64ab5776b"
    }
   }
  },
config.py CHANGED
@@ -1,8 +1,9 @@
 import tempfile
 
 VIDEO_DIRECTORY = tempfile.gettempdir()
 
 FPS = 5
 MIN_DISTANCE = 4
-MAX_DISTANCE = 30
 ROLLING_WINDOW_SIZE = 10
 
 import tempfile
 
 VIDEO_DIRECTORY = tempfile.gettempdir()
+# VIDEO_DIRECTORY = './data/'
 
 FPS = 5
 MIN_DISTANCE = 4
+MAX_DISTANCE = 30 # Used to be 30
 ROLLING_WINDOW_SIZE = 10
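
A small sketch (not part of the commit) of how these settings are consumed: app.py pulls them in via "from config import *" and videomatch.py imports an explicit subset, so changing a value here propagates to both.

    # Sketch only: config.py is a flat settings module read via plain imports.
    from config import FPS, MIN_DISTANCE, MAX_DISTANCE, ROLLING_WINDOW_SIZE, VIDEO_DIRECTORY

    # compare() in app.py forwards MIN_DISTANCE and MAX_DISTANCE to get_decent_distance(),
    # so the hash-distance search range is controlled from this file.
    print(FPS, MIN_DISTANCE, MAX_DISTANCE, ROLLING_WINDOW_SIZE, VIDEO_DIRECTORY)
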
plot.py CHANGED
@@ -69,33 +69,40 @@ def add_seconds_to_datetime64(datetime64, seconds, subtract=False):
         return datetime64 - np.timedelta64(int(s), 's') - np.timedelta64(int(m * 1000), 'ms')
     return datetime64 + np.timedelta64(int(s), 's') + np.timedelta64(int(m * 1000), 'ms')
 
-def plot_segment_comparison(df, change_points, video_id="Placeholder_Video_ID"):
-    """ From the dataframe plot the current set of plots, where the bottom right is most indicative """
-    fig, ax_arr = plt.subplots(3, 1, figsize=(16, 6), dpi=100, sharex=True)
 
-    # Plot original datapoints without linear interpolation, offset by target video time
-    sns.scatterplot(data = df, x='time', y='OFFSET', ax=ax_arr[0], label="OFFSET", alpha=0.5)
 
-    # Plot linearly interpolated values
-    sns.lineplot(data = df, x='time', y='OFFSET_LIP', ax=ax_arr[1], label="OFFSET_LIP", color='orange')
 
-    # Plot our target metric wherer
     metric = 'ROLL_OFFSET_MODE' # 'OFFSET'
-    sns.scatterplot(data = df, x='time', y=metric, ax=ax_arr[1], label=metric, alpha=0.5)
 
-    # Plot deteected change points as lines which will indicate the segments
-    sns.scatterplot(data = df, x='time', y=metric, ax=ax_arr[2], label=metric, s=20)
     timestamps = change_points_to_segments(df, change_points)
-
     # To store "decisions" about segments
     segment_decisions = []
     seg_i = 0
 
-    # To plot the detected segment lines
-    for x in timestamps:
-        plt.vlines(x=x, ymin=np.min(df[metric]), ymax=np.max(df[metric]), colors='black', lw=2, alpha=0.5)
-
-    threshold_diff = 1.5 # Average segment difference threshold for plotting
     for start_time, end_time in zip(timestamps[:-1], timestamps[1:]):
 
         # Time to add to each origin time to get the correct time back since it is offset by add_offset
@@ -149,4 +156,6 @@ def plot_segment_comparison(df, change_points, video_id="Placeholder_Video_ID"):
 
     # Return figure
     plt.xticks(rotation=90)
-    return fig, segment_decisions
 
 
 
         return datetime64 - np.timedelta64(int(s), 's') - np.timedelta64(int(m * 1000), 'ms')
     return datetime64 + np.timedelta64(int(s), 's') + np.timedelta64(int(m * 1000), 'ms')
 
+def plot_segment_comparison(df, change_points, video_id="Placeholder_Video_ID", threshold_diff = 1.5):
+    """ Based on the dataframe and detected change points do two things:
+    1. Make a decision on where each segment belongs in time and return that info as a list of dicts
+    2. Plot how this decision was made as an informative plot
 
+    args:
+    - df: dataframe
+    - change_points: detected points in time where the average metric value changes
+    - video_id: the unique identifier for the video currently being compared
+    - threshold_diff: to plot which segments are likely bad matches
+    """
+    fig, ax_arr = plt.subplots(4, 1, figsize=(16, 6), dpi=300, sharex=True)
+    ax_arr[0].set_title(video_id)
+    sns.scatterplot(data = df, x='time', y='SOURCE_S', ax=ax_arr[0], label="SOURCE_S", color='blue', alpha=1.0)
 
+    # Plot original datapoints without linear interpolation, offset by target video time
+    sns.scatterplot(data = df, x='time', y='OFFSET', ax=ax_arr[1], label="OFFSET", color='orange', alpha=1.0)
 
+    # Plot linearly interpolated values next to metric values
     metric = 'ROLL_OFFSET_MODE' # 'OFFSET'
+    sns.lineplot(data = df, x='time', y='OFFSET_LIP', ax=ax_arr[2], label="OFFSET_LIP", color='orange')
+    sns.scatterplot(data = df, x='time', y=metric, ax=ax_arr[2], label=metric, alpha=0.5)
 
+    # Plot detected change points as lines which will indicate the segments
+    sns.scatterplot(data = df, x='time', y=metric, ax=ax_arr[3], label=metric, s=20)
     timestamps = change_points_to_segments(df, change_points)
+    for x in timestamps:
+        plt.vlines(x=x, ymin=np.min(df[metric]), ymax=np.max(df[metric]), colors='black', lw=2, alpha=0.5)
+
     # To store "decisions" about segments
     segment_decisions = []
     seg_i = 0
 
+    # Average segment difference threshold for plotting
     for start_time, end_time in zip(timestamps[:-1], timestamps[1:]):
 
         # Time to add to each origin time to get the correct time back since it is offset by add_offset
 
     # Return figure
     plt.xticks(rotation=90)
+    return fig, segment_decisions
+
+
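
For orientation, a sketch (not from the commit) of the dataframe shape plot_segment_comparison() expects: the column names come from the plotting calls above, while the toy values, the time axis, and the change_points object are assumptions, so the actual call is left commented out.

    # Sketch only: toy dataframe with the columns read by plot_segment_comparison().
    import numpy as np
    import pandas as pd

    df = pd.DataFrame({
        "time": pd.date_range("2022-01-01", periods=60, freq="s"),  # assumed datetime axis
        "SOURCE_S": np.arange(60) / 5.0,                            # assumed source time in seconds (FPS = 5)
        "OFFSET": np.r_[np.full(30, 2.0), np.full(30, 40.0)],
        "OFFSET_LIP": np.r_[np.full(30, 2.0), np.full(30, 40.0)],
        "ROLL_OFFSET_MODE": np.r_[np.full(30, 2.0), np.full(30, 40.0)],
    })

    # change_points would come from get_change_points() in videomatch.py (kats CUSUM detections),
    # so a real call would pass that output here:
    # fig, decisions = plot_segment_comparison(df, change_points, video_id="example_target.mp4")
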
videomatch.py CHANGED
@@ -1,6 +1,6 @@
 import os
 import logging
-
 import faiss
 
 from kats.detectors.cusum_detection import CUSUMDetector
@@ -15,6 +15,12 @@ import pandas as pd
 from videohash import compute_hashes, filepath_from_url
 from config import FPS, MIN_DISTANCE, MAX_DISTANCE, ROLLING_WINDOW_SIZE
 
 def index_hashes_for_video(url: str) -> faiss.IndexBinaryIVF:
     """ Compute hashes of a video and index the video using faiss indices and return the index. """
     filepath = filepath_from_url(url)
 
 import os
 import logging
+import json
 import faiss
 
 from kats.detectors.cusum_detection import CUSUMDetector
 
 from videohash import compute_hashes, filepath_from_url
 from config import FPS, MIN_DISTANCE, MAX_DISTANCE, ROLLING_WINDOW_SIZE
 
+def get_target_urls(json_file='apb2022.json'):
+    """ Obtain target urls for the target videos of a json file containing .mp4 files """
+    with open(json_file, "r") as f:
+        target_videos = json.load(f)
+    return [video['mp4'] for video in target_videos]
+
 def index_hashes_for_video(url: str) -> faiss.IndexBinaryIVF:
     """ Compute hashes of a video and index the video using faiss indices and return the index. """
     filepath = filepath_from_url(url)
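
The new get_target_urls() assumes apb2022.json holds a list of video records, each with an mp4 field; the sketch below shows that expected shape with invented placeholder entries, not the real file contents.

    # Sketch only: expected shape of apb2022.json (placeholder URLs).
    # [
    #   {"mp4": "https://example.org/videos/target_part_1.mp4"},
    #   {"mp4": "https://example.org/videos/target_part_2.mp4"}
    # ]
    from videomatch import get_target_urls

    targets = get_target_urls()       # reads 'apb2022.json' by default
    print(len(targets), targets[:2])  # one .mp4 url per target video
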