File size: 4,699 Bytes
d0e55ca
 
7d6e745
d0e55ca
 
 
 
 
 
44356bd
 
d0e55ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18b0fa3
d0e55ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c317d7c
d0e55ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7fe8c9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# Maximum number of sampled frames decoded and hashed together in one batch.
BATCH_SIZE = 64
# Stride, in frame indices, between consecutive sampled frames of the video.
DOWNSAMPLE = 24
# Directory prefix under which extracted slide images are written.
FOLDER_PATH = "."

import phash_jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from PIL import Image
import statistics
from decord import VideoReader
from decord import cpu
import gradio

def binary_array_to_hex(arr):
    """
    Function to make a hex string out of a binary array.

    The array is flattened, each element is written as a ``0``/``1`` digit,
    and the resulting bit string is rendered as lowercase hexadecimal,
    left-padded with zeros to ``ceil(n_bits / 4)`` characters.

    Args:
        arr: array-like exposing ``flatten()`` whose elements are booleans
            or 0/1 integers.

    Returns:
        The hex string, or ``''`` when the array has no elements.
    """
    # `1 * ...` turns booleans into integers; a single `tolist()` then gives
    # plain Python ints instead of one array scalar per element.
    bits = (1 * arr.flatten()).tolist()
    if not bits:
        # int('', 2) would raise ValueError; an empty array has an empty hash.
        return ''
    bit_string = ''.join(str(b) for b in bits)
    # Integer ceil-division: one hex digit covers four bits.
    width = (len(bit_string) + 3) // 4
    return '{:0>{width}x}'.format(int(bit_string, 2), width=width)

def compute_batch_hashes(vid_path):
  """Hash every DOWNSAMPLE-th frame of a video, in batches.

  The video is opened with a requested frame size of 64x64. Frame indices
  are taken every DOWNSAMPLE frames and processed in groups of at most
  BATCH_SIZE; each group is hashed with ``phash_jax.batch_phash``.

  Args:
    vid_path: path of the video file, passed to ``decord.VideoReader``.

  Returns:
    A ``gradio.update`` whose value is a list with one dict per sampled
    frame: ``frame_id`` (frame index), ``hash`` (hex string) and
    ``distance`` (``phash_jax.hash_dist`` to the previously sampled frame's
    hash; the first frame is compared with itself). The update also sets
    ``visible=False`` on the receiving component.
  """
  vr = VideoReader(vid_path, ctx=cpu(0), width=64, height=64)
  n_frames = len(vr)
  frames_per_batch = DOWNSAMPLE * BATCH_SIZE
  hashes = []
  h_prev = None
  for batch_start in range(0, n_frames, frames_per_batch):
      print(f"batch_{batch_start}")
      batch_end = min(batch_start + frames_per_batch, n_frames)
      ids = list(range(batch_start, batch_end, DOWNSAMPLE))
      # NOTE(review): rewinding before every batched read is kept from the
      # original; presumably it works around decord seek state -- confirm.
      vr.seek(0)
      batch = jnp.array(vr.get_batch(ids).asnumpy())
      batch_h = phash_jax.batch_phash(batch)
      for pos, frame_id in enumerate(ids):
        h = batch_h[pos]
        # Identity test: `==` against an array may be evaluated element-wise
        # and has no single truth value.
        if h_prev is None:
          h_prev = h
        hashes.append({"frame_id": frame_id, "hash": binary_array_to_hex(h), "distance": int(phash_jax.hash_dist(h, h_prev))})
        h_prev = h
  return gradio.update(value=hashes, visible=False)

def plot_hash_distance(hashes, threshold):
  """Plot frame-to-frame hash distances with the threshold overlaid.

  Args:
    hashes: list of dicts carrying ``frame_id`` and ``distance`` keys.
    threshold: value drawn as a horizontal red line across all frame ids.

  Returns:
    The matplotlib figure that was created for the plot.
  """
  fig = plt.figure()
  frame_ids = []
  for entry in hashes:
    frame_ids.append(entry["frame_id"])
  dists = []
  for entry in hashes:
    dists.append(entry["distance"])
  # One dot per sampled frame.
  plt.plot(frame_ids, dists, ".")
  # Constant line at the threshold, spanning the same x values.
  threshold_line = [threshold for _ in frame_ids]
  plt.plot(frame_ids, threshold_line, "r-")
  return fig

def compute_threshold(hashes):
  """Pick a distance threshold that does not over-segment the video.

  Every distinct distance value except the largest is tried as a threshold,
  from high to low. For each one the hashes are scanned and a segment is
  counted whenever an entry's distance exceeds the threshold and the
  frame-id span since the current segment start is longer than
  ``min_length``. A candidate is accepted when its segment count stays
  below ``(len(hashes) * DOWNSAMPLE / 24) / 20``; the last accepted
  candidate wins.

  Args:
    hashes: non-empty list of dicts with ``frame_id`` and ``distance`` keys.

  Returns:
    The selected threshold, or the largest distance minus one when no
    candidate is accepted.
  """
  # Presumably three seconds of video at 24 fps -- confirm.
  min_length = 24 * 3
  candidates = sorted({entry["distance"] for entry in hashes}, reverse=True)
  best = candidates[0] - 1
  max_segments = (len(hashes) * DOWNSAMPLE / 24) / 20
  for threshold in candidates[1:]:
    segment_count = 0
    seg_start = 0
    for idx, entry in enumerate(hashes):
      if entry["distance"] <= threshold:
        continue
      span = hashes[idx - 1]["frame_id"] - hashes[seg_start]["frame_id"]
      if span > min_length:
        segment_count += 1
        seg_start = idx
    if segment_count < max_segments:
      best = threshold
  return best

def get_slides(vid_path, hashes, threshold):
    """Save one PNG per detected slide and return the image paths.

    A slide ends at hash entry ``i - 1`` when entry ``i`` has a distance
    above ``threshold`` and the frame-id span since the current slide start
    exceeds ``min_length``. The frame of entry ``i - 1`` is saved for that
    slide; the video's last frame is always saved as the final slide.

    File names are ``<FOLDER_PATH>/<video stem>_<start>_<end>.png``, where
    ``start``/``end`` are indices into ``hashes`` for intermediate slides
    and ``end`` is the last frame index of the video for the final one.

    Args:
        vid_path: path of the video file, using ``/`` as separator.
        hashes: list of dicts with ``frame_id`` and ``distance`` keys.
        threshold: distance above which a slide change is considered.

    Returns:
        List of the saved image paths, in slide order.
    """
    min_length = 24 * 1.5
    vr = VideoReader(vid_path, ctx=cpu(0))
    # Base name of the video up to its first dot.
    stem = vid_path.split("/")[-1].split(".")[0]
    slide_paths = []
    seg_start = 0
    for idx, entry in enumerate(hashes):
        if entry["distance"] <= threshold:
            continue
        prev_frame = hashes[idx - 1]["frame_id"]
        if prev_frame - hashes[seg_start]["frame_id"] <= min_length:
            continue
        out_path = f'{FOLDER_PATH}/{stem}_{seg_start}_{idx - 1}.png'
        Image.fromarray(vr[prev_frame].asnumpy()).save(out_path)
        slide_paths.append(out_path)
        seg_start = idx
    # Closing slide: always taken from the last frame of the video.
    last_frame = len(vr) - 1
    out_path = f'{FOLDER_PATH}/{stem}_{seg_start}_{last_frame}.png'
    Image.fromarray(vr[-1].asnumpy()).save(out_path)
    slide_paths.append(out_path)
    return slide_paths

def trigger_plots(f2f_distance_plot, hashes, threshold):
  """Build the hash-distance figure and wrap it in a Gradio update.

  Args:
    f2f_distance_plot: current value of the plot component; not used.
    hashes: list of hash dicts forwarded to ``plot_hash_distance``.
    threshold: threshold forwarded to ``plot_hash_distance``.

  Returns:
    A ``gradio.update`` whose value is the new figure.
  """
  figure = plot_hash_distance(hashes, threshold)
  return gradio.update(value=figure)

def set_visible():
  """Return a Gradio update that sets ``visible=True`` on its output."""
  reveal = gradio.update(visible=True)
  return reveal
  
# Gradio UI: layout of the components, event wiring, then launch.
demo = gradio.Blocks()

with demo:
    with gradio.Row():
        # Left column: video input and the button that starts hashing.
        with gradio.Column():
          with gradio.Row():
            vid=gradio.Video(mirror_webcam=False)
          with gradio.Row():
            btn_vid_proc = gradio.Button("Compute hashes")
          with gradio.Row():
            # Created hidden; no event handler below references it.
            hist_plot = gradio.Plot(label="Frame to frame hash distance histogram", visible=False)
        # Right column: hash output plus the controls that depend on it.
        with gradio.Column():
            hashes = gradio.JSON()
            # Hidden until `hashes` changes (see hashes.change below).
            with gradio.Column(visible=False) as result_row:
              btn_plot = gradio.Button("Plot & compute optimal threshold")
              threshold = gradio.Slider(minimum=1, maximum=30, value=5, label="Threshold")
              f2f_distance_plot = gradio.Plot(label="Frame to frame hash distance")
              btn_slides = gradio.Button("Extract Slides")
    with gradio.Row():
      slideshow = gradio.Gallery(label="Extracted slides")
      # NOTE(review): presumably a 6-column gallery grid -- confirm against
      # the installed Gradio version.
      slideshow.style(grid=6)
    # Hash the uploaded video and store the result in the JSON component.
    btn_vid_proc.click(fn=compute_batch_hashes, inputs=[vid], outputs=[hashes])
    # Reveal the result controls once hashes are available.
    hashes.change(fn=set_visible, inputs=[], outputs=[result_row])
    # Two handlers on the same click: one sets the slider, one redraws the plot.
    # NOTE(review): trigger_plots takes the slider value as input; whether it
    # sees the value before or after compute_threshold updates it is not
    # established here -- confirm.
    btn_plot.click(fn=compute_threshold, inputs=[hashes], outputs=[threshold])
    btn_plot.click(fn=trigger_plots, inputs=[f2f_distance_plot, hashes, threshold], outputs=[f2f_distance_plot])
    # Redraw the plot whenever the slider value changes.
    threshold.change(fn=plot_hash_distance, inputs=[hashes, threshold], outputs=f2f_distance_plot)
    # Extract slide images and show them in the gallery.
    btn_slides.click(fn=get_slides, inputs=[vid, hashes, threshold], outputs=[slideshow])

# Runs at import time: there is no `__main__` guard around the launch.
demo.queue(default_enabled=True).launch()