gigant commited on
Commit
d0e55ca
1 Parent(s): 79b8121

Create new file

Browse files
Files changed (1) hide show
  1. app.py +114 -0
app.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BATCH_SIZE = 64
2
+ DOWNSAMPLE = 24
3
+
4
+ import phash_jax
5
+ import jax.numpy as jnp
6
+ import matplotlib.pyplot as plt
7
+ from PIL import Image
8
+ import statistics
9
+ import gradio
10
+
11
+ def binary_array_to_hex(arr):
12
+ """
13
+ Function to make a hex string out of a binary array.
14
+ """
15
+ bit_string = ''.join(str(b) for b in 1 * arr.flatten())
16
+ width = int(jnp.ceil(len(bit_string) / 4))
17
+ return '{:0>{width}x}'.format(int(bit_string, 2), width=width)
18
+
19
+ def compute_batch_hashes(vid_path):
20
+ kwargs={"width": 64, "height":64}
21
+ vr = VideoReader(vid_path, ctx=cpu(0), **kwargs)
22
+ hashes = []
23
+ h_prev = None
24
+ batch = []
25
+ for i in range(0, len(vr), DOWNSAMPLE * BATCH_SIZE):
26
+ ids = [id for id in range(i, min(i + DOWNSAMPLE * BATCH_SIZE, len(vr)), DOWNSAMPLE)]
27
+ vr.seek(0)
28
+ batch = jnp.array(vr.get_batch(ids).asnumpy())
29
+ batch_h = phash_jax.batch_phash(batch)
30
+ for i in range(len(ids)):
31
+ h = batch_h[i]
32
+ if h_prev == None:
33
+ h_prev=h
34
+ hashes.append({"frame_id":ids[i], "hash": binary_array_to_hex(h), "distance": int(phash_jax.hash_dist(h, h_prev))})
35
+ h_prev = h
36
+ return gradio.update(value=hashes, visible=False)
37
+
38
+ def plot_hash_distance(hashes, threshold):
39
+ fig = plt.figure()
40
+ ids = [h["frame_id"] for h in hashes]
41
+ distances = [h["distance"] for h in hashes]
42
+ plt.plot(ids, distances, ".")
43
+ plt.plot(ids, [threshold]* len(ids), "r-")
44
+ return fig
45
+
46
+ def compute_threshold(hashes):
47
+ min_length = 24 * 3
48
+ ids = [h["frame_id"] for h in hashes]
49
+ distances = [h["distance"] for h in hashes]
50
+ thrs_ = sorted(list(set(distances)),reverse=True)
51
+ best = thrs_[0] - 1
52
+ for threshold in thrs_[1:]:
53
+ durations = []
54
+ i_start=0
55
+ for i, h in enumerate(hashes):
56
+ if h["distance"] > threshold and hashes[i-1]["frame_id"] - hashes[i_start]["frame_id"] > min_length:
57
+ durations.append(hashes[i-1]["frame_id"] - hashes[i_start]["frame_id"])
58
+ i_start=i
59
+ if len(durations) < (len(hashes) * DOWNSAMPLE / 24) / 20:
60
+ best = threshold
61
+ return best
62
+
63
+ def get_slides(vid_path, hashes, threshold):
64
+ min_length = 24 * 1.5
65
+ vr = VideoReader(vid_path, ctx=cpu(0))
66
+ slideshow = []
67
+ i_start = 0
68
+ for i, h in enumerate(hashes):
69
+ if h["distance"] > threshold and hashes[i-1]["frame_id"] - hashes[i_start]["frame_id"] > min_length:
70
+ path=f'{FOLDER_PATH}/{vid_path.split("/")[-1].split(".")[0]}_{i_start}_{i-1}.png'
71
+ Image.fromarray(vr[hashes[i-1]["frame_id"]].asnumpy()).save(path)
72
+ slideshow.append({"slide": path, "start": i_start, "end": i-1})
73
+ i_start=i
74
+ path=f'{FOLDER_PATH}/{vid_path.split("/")[-1].split(".")[0]}_{i_start}_{len(vr)-1}.png'
75
+ Image.fromarray(vr[-1].asnumpy()).save(path)
76
+ slideshow.append({"slide": path, "start": i_start, "end": len(vr)-1})
77
+ return [s["slide"] for s in slideshow]
78
+
79
+ def trigger_plots(f2f_distance_plot, hashes, threshold):
80
+ # if not hist_plot.get_config()["visible"] and len(hashes.get_config()["value"]) > 0 :
81
+ return gradio.update(value=plot_hash_distance(hashes, threshold))
82
+
83
+ def set_visible():
84
+ return gradio.update(visible=True)
85
+
86
+ demo = gradio.Blocks(analytics_enabled=True)
87
+
88
+ with demo:
89
+ with gradio.Row():
90
+ with gradio.Column():
91
+ with gradio.Row():
92
+ vid=gradio.Video(mirror_webcam=False)
93
+ with gradio.Row():
94
+ btn_vid_proc = gradio.Button("Compute hashes")
95
+ with gradio.Row():
96
+ hist_plot = gradio.Plot(label="Frame to frame hash distance histogram", visible=False)
97
+ with gradio.Column():
98
+ hashes = gradio.JSON()
99
+ with gradio.Column(visible=False) as result_row:
100
+ btn_plot = gradio.Button("Plot & compute optimal threshold")
101
+ threshold = gradio.Slider(minimum=1, maximum=30, value=5, label="Threshold")
102
+ f2f_distance_plot = gradio.Plot(label="Frame to frame hash distance")
103
+ btn_slides = gradio.Button("Extract Slides")
104
+ with gradio.Row():
105
+ slideshow = gradio.Gallery(label="Extracted slides")
106
+ slideshow.style(grid=6)
107
+ btn_vid_proc.click(fn=compute_batch_hashes, inputs=[vid], outputs=[hashes])
108
+ hashes.change(fn=set_visible, inputs=[], outputs=[result_row])
109
+ btn_plot.click(fn=compute_threshold, inputs=[hashes], outputs=[threshold])
110
+ btn_plot.click(fn=trigger_plots, inputs=[f2f_distance_plot, hashes, threshold], outputs=[f2f_distance_plot])
111
+ threshold.change(fn=plot_hash_distance, inputs=[hashes, threshold], outputs=f2f_distance_plot)
112
+ btn_slides.click(fn=get_slides, inputs=[vid, hashes, threshold], outputs=[slideshow])
113
+
114
+ demo.launch(debug=True)