File size: 830 Bytes
79b8121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import jax
import jax.numpy as jnp

def convert_L(image):
  """Convert an RGB image to 8-bit greyscale with the ITU-R 601-2 luma transform.

  Computes ``L = 0.299 * R + 0.587 * G + 0.114 * B`` per pixel, rounds to the
  nearest integer, clamps to [0, 255] and casts to uint8.

  Args:
    image: array whose last axis holds the channels, with R, G, B at indices
      0, 1, 2 (any further channels are ignored). Leading axes are preserved,
      so both (H, W, C) images and stacks of them are accepted.

  Returns:
    A uint8 array shaped like ``image`` without its last axis.
  """
  red, green, blue = image[..., 0], image[..., 1], image[..., 2]
  luma = red * 0.299 + green * 0.587 + blue * 0.114
  # Round half up instead of truncating: a bare uint8 cast floors the value, so
  # e.g. pure green (0, 255, 0) -> 149.685 would become 149 rather than 150, and
  # float error can push exact grey levels one step too low.
  # NOTE(review): Pillow documents this same luma formula for convert('L') and
  # applies dithering only to palette/bilevel conversions; results here should
  # be close to Pillow's but are not guaranteed bit-identical - confirm if
  # exact parity matters.
  rounded = jnp.floor(luma + 0.5)
  # Clamp before the cast so out-of-range inputs saturate instead of wrapping.
  return jnp.clip(rounded, 0, 255).astype("uint8")
  
def phash_jax(image, hash_size=8, highfreq_factor=4):
  """Compute a DCT-based perceptual hash of a single image.

  The image is converted to greyscale with ``convert_L``, resized to a square
  of side ``hash_size * highfreq_factor`` using Lanczos-3 resampling, and
  transformed with a 2-D DCT. The top-left ``hash_size`` x ``hash_size``
  coefficients are then compared against their own median.

  Args:
    image: a single image in the layout ``convert_L`` accepts.
    hash_size: side length of the low-frequency coefficient block that is kept.
    highfreq_factor: multiplier applied to ``hash_size`` to get the resize side.

  Returns:
    A boolean array, True where a retained DCT coefficient exceeds the median
    of the retained block.
  """
  side = hash_size * highfreq_factor
  grey = convert_L(image)
  shrunk = jax.image.resize(grey, [side, side], "lanczos3")
  # The 2-D DCT is applied separably: first along axis 0, then along axis 1.
  coeffs = jax.scipy.fft.dct(shrunk, axis=0)
  coeffs = jax.scipy.fft.dct(coeffs, axis=1)
  # Only the low-frequency corner contributes to the hash.
  low_block = coeffs[:hash_size, :hash_size]
  return low_block > jnp.median(low_block)

def hash_dist(h1, h2):
  """Return the number of positions at which two hashes differ.

  Both hashes are flattened to 1-D before the element-wise comparison, so the
  result is the Hamming distance between them.

  Args:
    h1: first hash array.
    h2: second hash array, with the same number of elements as ``h1``.

  Returns:
    A scalar integer array holding the count of mismatching elements.
  """
  mismatches = jnp.not_equal(h1.ravel(), h2.ravel())
  return jnp.count_nonzero(mismatches)

# Batched perceptual hash: the jit-compiled phash_jax is mapped over axis 0 of
# its input (a stack of images), producing one hash per image along axis 0.
batch_phash = jax.vmap(
  jax.jit(phash_jax),
  in_axes=0,
  out_axes=0,
)