Iskaj commited on
Commit
2a1a736
1 Parent(s): 0112deb

add segment based decision plotting

Browse files
Files changed (2) hide show
  1. app.py +3 -3
  2. plot.py +63 -0
app.py CHANGED
@@ -5,13 +5,13 @@ import gradio as gr
5
  from config import *
6
  from videomatch import index_hashes_for_video, get_decent_distance, \
7
  get_video_index, compare_videos, get_change_points, get_videomatch_df
8
- from plot import plot_comparison, plot_multi_comparison
9
 
10
  logging.basicConfig()
11
  logging.getLogger().setLevel(logging.INFO)
12
 
13
 
14
- def get_comparison(url, target, MIN_DISTANCE = 4):
15
  """ Function for Gradio to combine all helper functions"""
16
  video_index, hash_vectors = get_video_index(url)
17
  target_index, _ = get_video_index(target)
@@ -31,7 +31,7 @@ def get_auto_comparison(url, target, smoothing_window_size=10, method="CUSUM"):
31
  # fig = plot_comparison(lims, D, I, hash_vectors, MIN_DISTANCE = distance)
32
  df = get_videomatch_df(url, target, min_distance=MIN_DISTANCE, vanilla_df=False)
33
  change_points = get_change_points(df, smoothing_window_size=smoothing_window_size, method=method)
34
- fig = plot_multi_comparison(df, change_points)
35
  return fig
36
 
37
  def get_auto_edit_decision(url, target, smoothing_window_size=10):
 
5
  from config import *
6
  from videomatch import index_hashes_for_video, get_decent_distance, \
7
  get_video_index, compare_videos, get_change_points, get_videomatch_df
8
+ from plot import plot_comparison, plot_multi_comparison, plot_segment_comparison
9
 
10
  logging.basicConfig()
11
  logging.getLogger().setLevel(logging.INFO)
12
 
13
 
14
+ def get_comparison(url, target, MIN_DISTANCE = 4):
15
  """ Function for Gradio to combine all helper functions"""
16
  video_index, hash_vectors = get_video_index(url)
17
  target_index, _ = get_video_index(target)
 
31
  # fig = plot_comparison(lims, D, I, hash_vectors, MIN_DISTANCE = distance)
32
  df = get_videomatch_df(url, target, min_distance=MIN_DISTANCE, vanilla_df=False)
33
  change_points = get_change_points(df, smoothing_window_size=smoothing_window_size, method=method)
34
+ fig = plot_segment_comparison(df, change_points)
35
  return fig
36
 
37
  def get_auto_edit_decision(url, target, smoothing_window_size=10):
plot.py CHANGED
@@ -55,4 +55,67 @@ def plot_multi_comparison(df, change_points):
55
  rand_y_pos = np.random.uniform(low=np.min(df['OFFSET_LIP']), high=np.max(df['OFFSET_LIP']), size=None)
56
  plt.text(x=cp_time, y=rand_y_pos, s=str(np.round(x.confidence, 2)), color='r', rotation=-0.0, fontsize=14)
57
  plt.xticks(rotation=90)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  return fig
 
55
  rand_y_pos = np.random.uniform(low=np.min(df['OFFSET_LIP']), high=np.max(df['OFFSET_LIP']), size=None)
56
  plt.text(x=cp_time, y=rand_y_pos, s=str(np.round(x.confidence, 2)), color='r', rotation=-0.0, fontsize=14)
57
  plt.xticks(rotation=90)
58
+ return fig
59
+
60
+ def change_points_to_segments(df, change_points):
61
+ """ Convert change points from kats detector to segment indicators """
62
+ return [pd.to_datetime(0.0, unit='s').to_datetime64()] + [cp.start_time for cp in change_points] + [pd.to_datetime(df.iloc[-1]['TARGET_S'], unit='s').to_datetime64()]
63
+
64
+ def add_seconds_to_datetime64(datetime64, seconds, subtract=False):
65
+ """Add or substract a number of seconds to a np.datetime64 object """
66
+ s, m = divmod(seconds, 1.0)
67
+ if subtract:
68
+ return datetime64 - np.timedelta64(int(s), 's') - np.timedelta64(int(m * 1000), 'ms')
69
+ return datetime64 + np.timedelta64(int(s), 's') + np.timedelta64(int(m * 1000), 'ms')
70
+
71
+ def plot_segment_comparison(df, change_points):
72
+ """ From the dataframe plot the current set of plots, where the bottom right is most indicative """
73
+ fig, ax_arr = plt.subplots(2, 2, figsize=(12, 4), dpi=100, sharex=True)
74
+ sns.scatterplot(data = df, x='time', y='SOURCE_S', ax=ax_arr[0,0])
75
+ sns.lineplot(data = df, x='time', y='SOURCE_LIP_S', ax=ax_arr[0,1])
76
+
77
+ # Plot change point as lines
78
+ sns.lineplot(data = df, x='time', y='OFFSET_LIP', ax=ax_arr[1,0])
79
+ sns.lineplot(data = df, x='time', y='OFFSET_LIP', ax=ax_arr[1,1])
80
+ timestamps = change_points_to_segments(df, change_points)
81
+
82
+ # To plot the detected segment lines
83
+ for x in timestamps:
84
+ plt.vlines(x=x, ymin=np.min(df['OFFSET_LIP']), ymax=np.max(df['OFFSET_LIP']), colors='black', lw=2)
85
+ rand_y_pos = np.random.uniform(low=np.min(df['OFFSET_LIP']), high=np.max(df['OFFSET_LIP']), size=None)
86
+
87
+ # To get each detected segment and their mean?
88
+ threshold_diff = 1.5 # Average diff threshold
89
+ # threshold = 3.0 # s diff threshold
90
+ for start_time, end_time in zip(timestamps[:-1], timestamps[1:]):
91
+
92
+ add_offset = np.min(df['SOURCE_S'])
93
+
94
+ # Cut out the segment between the segment lines
95
+ segment = df[(df['time'] > start_time) & (df['time'] < end_time)] # Not offset LIP
96
+ segment_no_nan = segment[~np.isnan(segment['OFFSET'])] # Remove NaNs
97
+ seg_mean = np.mean(segment_no_nan['OFFSET'])
98
+
99
+ # Get average difference from mean of the segment to see if it is a "straight line" or not
100
+ # segment_no_nan = segment['OFFSET'][~np.isnan(segment['OFFSET'])] # Remove NaNs
101
+ average_diff = np.mean(np.abs(segment_no_nan['OFFSET'] - seg_mean))
102
+
103
+ # If the time where the segment comes from (origin time) is close to the start_time, it's a "good match", so no editing
104
+ prefix = "GOOD" if average_diff < threshold_diff else "BAD"
105
+ origin_time = add_seconds_to_datetime64(start_time, seg_mean + add_offset)
106
+ # prefix = "BAD"
107
+ # if (start_time < add_seconds_to_datetime64(origin_time, threshold) and (start_time > add_seconds_to_datetime64(origin_time, threshold, subtract=True))):
108
+ # prefix = "GOOD"
109
+
110
+ # Plot green for a confident prediction (straight line), red otherwise
111
+ if prefix == "GOOD":
112
+ plt.text(x=start_time, y=seg_mean, s=str(np.round(average_diff, 1)), color='g', rotation=-0.0, fontsize=14)
113
+ else:
114
+ plt.text(x=start_time, y=seg_mean, s=str(np.round(average_diff, 1)), color='r', rotation=-0.0, fontsize=14)
115
+
116
+ print(f"[{prefix}] DIFF={average_diff:.1f} MEAN={seg_mean:.1f} {start_time} -> {end_time} comes from video X, from {origin_time}")
117
+
118
+
119
+ # Return figure
120
+ plt.xticks(rotation=90)
121
  return fig