# slideshow_extraction / phash_jax.py
# Uploaded by gigant ("Upload phash_jax.py", commit 79b8121)
import jax
import jax.numpy as jnp
def convert_L(image):
    """Convert an image to 8-bit greyscale with the ITU-R 601-2 luma transform.

    Uses the same channel weights PIL documents for ``Image.convert("L")``:
    L = 0.299 * R + 0.587 * G + 0.114 * B.

    Args:
        image: Either an array indexed as ``image[row, col, channel]`` whose
            first three channels are R, G, B (any further channels, e.g. alpha,
            are ignored), or a 2-D array that is already single-channel, which
            skips the luma mix and only gets the clamp/cast below.
            Values are assumed to be on a 0-255 scale (TODO confirm with
            callers): float images in [0, 1] would collapse to 0/1.

    Returns:
        A 2-D ``uint8`` array with the image's height and width.
    """
    if image.ndim == 2:
        # Already greyscale; indexing a channel axis would fail here.
        grey = image
    else:
        grey = image[:, :, 0] * 0.299 + image[:, :, 1] * 0.587 + image[:, :, 2] * 0.114
    # Clamp into the uint8 range before casting. The float -> uint8 cast
    # truncates, so e.g. 76.9 becomes 76.
    # NOTE(review): Pillow's RGB -> "L" conversion is a plain weighted sum
    # (dithering only applies when converting to modes "1" or "P"). Recent
    # Pillow versions appear to round to nearest rather than truncate, so
    # results may differ from PIL by 1 -- confirm if bit-exact parity matters.
    return jnp.clip(grey, 0, 255).astype("uint8")
def phash_jax(image, hash_size=8, highfreq_factor=4):
    """Compute a DCT-based perceptual hash of a single image.

    Steps:
      1. Reduce the image to greyscale with ``convert_L``.
      2. Resize it to a square of side ``hash_size * highfreq_factor`` using a
         Lanczos-3 kernel.
      3. Apply a 1-D DCT along axis 0, then along axis 1.
      4. Threshold the ``hash_size x hash_size`` block of lowest-frequency
         coefficients against that block's median.

    Args:
        image: Image array in a form accepted by ``convert_L``.
        hash_size: Side length of the returned bit grid.
        highfreq_factor: Oversampling factor; the resize target has
            ``hash_size * highfreq_factor`` pixels per side. Both parameters
            set array shapes, so they need to be concrete Python ints rather
            than traced values.

    Returns:
        Boolean array of shape ``(hash_size, hash_size)`` that is True where a
        coefficient is strictly greater than the median.
    """
    side = hash_size * highfreq_factor
    greyscale = convert_L(image)
    shrunk = jax.image.resize(greyscale, (side, side), "lanczos3")

    coeffs = jax.scipy.fft.dct(shrunk, axis=0)
    coeffs = jax.scipy.fft.dct(coeffs, axis=1)

    low_band = coeffs[:hash_size, :hash_size]
    threshold = jnp.median(low_band)
    return low_band > threshold
def hash_dist(h1, h2):
    """Return the Hamming distance between two hashes.

    Both inputs are flattened in row-major order and compared element-wise.
    The result is the number of differing positions, returned as a 0-d JAX
    integer array. Hashes with equal element counts but different shapes are
    compared position-by-position after flattening.
    """
    differing = h1.ravel() != h2.ravel()
    return jnp.count_nonzero(differing)
def _batch_phash(images, hash_size=8, highfreq_factor=4):
    """Apply ``phash_jax`` to every image stacked along axis 0 of ``images``.

    Returns a boolean array of shape ``(num_images, hash_size, hash_size)``.
    """
    # Close over the hash parameters instead of passing them through vmap.
    # vmap maps every argument along axis 0, but these values set array
    # shapes inside phash_jax and must stay plain Python ints.
    return jax.vmap(lambda img: phash_jax(img, hash_size, highfreq_factor))(images)


# hash_size / highfreq_factor are static, so each distinct configuration is
# compiled once. `batch_phash(images)` behaves exactly like the previous
# `jax.vmap(jax.jit(phash_jax))(images)`.
batch_phash = jax.jit(_batch_phash, static_argnames=("hash_size", "highfreq_factor"))