"""Gradio app that extracts slide images from a video using perceptual hashes.

One frame out of every DOWNSAMPLE frames is hashed with ``phash_jax``.
A new slide is cut when the hash distance between consecutive sampled
frames exceeds a user-chosen threshold and the current segment is long
enough.
"""
import os
import statistics

import gradio
import jax.numpy as jnp
import matplotlib.pyplot as plt
import phash_jax
from PIL import Image

# NOTE(review): ``VideoReader`` and ``cpu`` are used below but are not imported
# in this file. Their usage matches decord's API
# (``from decord import VideoReader, cpu``). Confirm and add that import.

BATCH_SIZE = 64  # number of sampled frames hashed per phash_jax call
DOWNSAMPLE = 24  # keep one frame out of every DOWNSAMPLE frames
FOLDER_PATH = "./tmp"  # directory where extracted slide images are written


def binary_array_to_hex(arr):
    """Return the hexadecimal string encoding of a binary array.

    The array is flattened in row-major order and read as one big-endian
    bit string. The result is zero-padded to ceil(n_bits / 4) hex digits so
    leading zero bits are preserved. An empty array yields ``""``.
    """
    # ``tolist`` transfers the data to the host once instead of converting
    # each element separately.
    bit_string = "".join(str(int(b)) for b in arr.flatten().tolist())
    if not bit_string:
        return ""
    width = (len(bit_string) + 3) // 4  # integer ceil(n_bits / 4)
    return f"{int(bit_string, 2):0{width}x}"


def compute_batch_hashes(vid_path):
    """Hash every DOWNSAMPLE-th frame of a video and record successive distances.

    Frames are decoded at 64x64 and hashed BATCH_SIZE at a time with
    ``phash_jax.batch_phash``.

    Args:
        vid_path: path to the video file.

    Returns:
        A ``gradio.update`` (component kept hidden) whose value is a list of
        dicts with keys ``frame_id`` (index in the original video), ``hash``
        (hex string) and ``distance`` (``phash_jax.hash_dist`` to the
        previous sampled frame; the first frame is compared with itself).
    """
    vr = VideoReader(vid_path, ctx=cpu(0), width=64, height=64)
    n_frames = len(vr)
    stride = DOWNSAMPLE * BATCH_SIZE  # video frames covered by one batch
    hashes = []
    h_prev = None
    for batch_start in range(0, n_frames, stride):
        print(f"batch_{batch_start}")
        ids = list(range(batch_start, min(batch_start + stride, n_frames), DOWNSAMPLE))
        # NOTE(review): kept from the original; presumably resets the
        # reader's position before random-access decoding. Confirm it is needed.
        vr.seek(0)
        batch = jnp.array(vr.get_batch(ids).asnumpy())
        batch_h = phash_jax.batch_phash(batch)
        for frame_id, h in zip(ids, batch_h):
            # ``is None``: ``== None`` on an array is elementwise/ambiguous.
            if h_prev is None:
                h_prev = h
            hashes.append(
                {
                    "frame_id": frame_id,
                    "hash": binary_array_to_hex(h),
                    "distance": int(phash_jax.hash_dist(h, h_prev)),
                }
            )
            h_prev = h
    return gradio.update(value=hashes, visible=False)


def plot_hash_distance(hashes, threshold):
    """Plot per-frame hash distances with the threshold as a red horizontal line.

    Args:
        hashes: list of dicts with ``frame_id`` and ``distance`` keys, as
            produced by ``compute_batch_hashes``. ``None`` or empty yields a
            plot containing only the threshold line.
        threshold: value drawn as the cut-off line.

    Returns:
        The matplotlib Figure.
    """
    hashes = hashes or []
    ids = [h["frame_id"] for h in hashes]
    distances = [h["distance"] for h in hashes]
    fig, ax = plt.subplots()
    ax.plot(ids, distances, ".")
    ax.axhline(threshold, color="r")
    # Drop pyplot's global reference so repeated callbacks do not accumulate
    # open figures. The Figure object itself can still be rendered/saved.
    plt.close(fig)
    return fig


def _split_segments(hashes, threshold, min_length):
    """Split ``hashes`` at distance spikes.

    A cut is made before index ``i`` when ``hashes[i]["distance"]`` exceeds
    ``threshold`` and the current segment spans more than ``min_length``
    video frames. Spikes inside shorter segments are absorbed.

    Returns:
        ``(closed, last_start)``. ``closed`` lists ``(start, end)`` inclusive
        index pairs into ``hashes`` for every segment ended by a cut.
        ``last_start`` is the index where the final, open segment begins.
    """
    closed = []
    start = 0
    # Start at 1 so ``hashes[i - 1]`` never wraps around to the last element.
    for i in range(1, len(hashes)):
        span = hashes[i - 1]["frame_id"] - hashes[start]["frame_id"]
        if hashes[i]["distance"] > threshold and span > min_length:
            closed.append((start, i - 1))
            start = i
    return closed, start


def compute_threshold(hashes):
    """Heuristically choose a slide-cut threshold from the observed distances.

    The distinct distance values are tried as thresholds from largest to
    smallest. A candidate is accepted when it produces fewer cuts than one
    per 20 seconds of video, counting only cuts after segments longer than
    72 frames. The smallest accepted candidate is returned. If none is
    accepted, the result is one below the largest distance.

    Args:
        hashes: list of dicts as produced by ``compute_batch_hashes``.

    Returns:
        The chosen threshold, or a no-op ``gradio.update()`` when ``hashes``
        is empty (leaving the slider unchanged).
    """
    if not hashes:
        return gradio.update()
    min_length = 24 * 3
    # len(hashes) * DOWNSAMPLE approximates the frame count. The / 24
    # presumably converts frames to seconds at 24 fps (TODO confirm).
    max_cuts = (len(hashes) * DOWNSAMPLE / 24) / 20
    candidates = sorted({h["distance"] for h in hashes}, reverse=True)
    best = candidates[0] - 1
    for threshold in candidates[1:]:
        cuts, _ = _split_segments(hashes, threshold, min_length)
        if len(cuts) < max_cuts:
            best = threshold
    return best


def get_slides(vid_path, hashes, threshold):
    """Save one full-resolution image per detected slide and return their paths.

    A cut is made wherever a frame-to-frame hash distance exceeds
    ``threshold``, but only once the current segment spans more than 36 frames
    (presumably 1.5 s at 24 fps). Each closed segment is represented by its
    last sampled frame. The final segment is represented by the video's last frame.

    Args:
        vid_path: path to the video file.
        hashes: list of dicts as produced by ``compute_batch_hashes``.
        threshold: distance above which a new slide starts.

    Returns:
        PNG paths under FOLDER_PATH, in video order. File names are
        ``<video stem>_<first frame>_<last frame>.png`` using video frame
        indices.
    """
    min_length = 24 * 1.5
    hashes = hashes or []
    vr = VideoReader(vid_path, ctx=cpu(0))
    n_frames = len(vr)
    # splitext/basename handle dotted file names and OS-specific separators.
    stem = os.path.splitext(os.path.basename(vid_path))[0]
    # Image.save does not create missing directories.
    os.makedirs(FOLDER_PATH, exist_ok=True)

    def save(frame_idx, first_frame, last_frame):
        # Write video frame ``frame_idx`` as PNG and return its path.
        path = f"{FOLDER_PATH}/{stem}_{first_frame}_{last_frame}.png"
        Image.fromarray(vr[frame_idx].asnumpy()).save(path)
        return path

    segments, last_start = _split_segments(hashes, threshold, min_length)
    slides = []
    for start, end in segments:
        end_frame = hashes[end]["frame_id"]
        slides.append(save(end_frame, hashes[start]["frame_id"], end_frame))
    first_frame = hashes[last_start]["frame_id"] if hashes else 0
    slides.append(save(n_frames - 1, first_frame, n_frames - 1))
    return slides


def trigger_plots(f2f_distance_plot, hashes, threshold):
    """Re-render the frame-to-frame distance plot.

    ``f2f_distance_plot`` receives the plot's current value. It is unused
    but kept because the UI wiring passes it as an input.
    """
    return gradio.update(value=plot_hash_distance(hashes, threshold))


def set_visible():
    """Return an update that makes the target component visible."""
    return gradio.update(visible=True)


demo = gradio.Blocks(analytics_enabled=True)
with demo:
    with gradio.Row():
        with gradio.Column():
            with gradio.Row():
                vid = gradio.Video(mirror_webcam=False)
            with gradio.Row():
                btn_vid_proc = gradio.Button("Compute hashes")
            with gradio.Row():
                hist_plot = gradio.Plot(
                    label="Frame to frame hash distance histogram", visible=False
                )
        with gradio.Column():
            hashes = gradio.JSON()
        with gradio.Column(visible=False) as result_row:
            btn_plot = gradio.Button("Plot & compute optimal threshold")
            threshold = gradio.Slider(minimum=1, maximum=30, value=5, label="Threshold")
            f2f_distance_plot = gradio.Plot(label="Frame to frame hash distance")
            btn_slides = gradio.Button("Extract Slides")
    with gradio.Row():
        slideshow = gradio.Gallery(label="Extracted slides")
        slideshow.style(grid=6)

    btn_vid_proc.click(fn=compute_batch_hashes, inputs=[vid], outputs=[hashes])
    hashes.change(fn=set_visible, inputs=[], outputs=[result_row])
    btn_plot.click(fn=compute_threshold, inputs=[hashes], outputs=[threshold])
    btn_plot.click(
        fn=trigger_plots,
        inputs=[f2f_distance_plot, hashes, threshold],
        outputs=[f2f_distance_plot],
    )
    threshold.change(
        fn=plot_hash_distance, inputs=[hashes, threshold], outputs=f2f_distance_plot
    )
    btn_slides.click(fn=get_slides, inputs=[vid, hashes, threshold], outputs=[slideshow])

# ``cache_examples`` is not a ``Blocks.launch`` argument; where it was
# accepted its default was already False.
demo.launch()