File size: 5,291 Bytes
7b7de41
16348d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53225e3
16348d6
 
 
 
 
 
 
53225e3
16348d6
 
 
 
53225e3
16348d6
 
 
 
53225e3
16348d6
 
 
 
53225e3
16348d6
53225e3
 
 
16348d6
 
 
 
 
 
 
 
 
53225e3
 
 
 
 
 
 
 
16348d6
 
 
 
 
 
 
53225e3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import gradio as gr
import numpy as np
import tensorflow as tf
import logging
from PIL import Image
from tensorflow.keras.preprocessing import image as keras_image
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input as resnet_preprocess
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input as vgg_preprocess
import scipy.fftpack
import time
import clip
import torch

# Set up logging
# The root logger is configured at INFO so the timing messages that the
# functions below emit through logging.info are not filtered out.
logging.basicConfig(level=logging.INFO)

# Load models
# All models are created once, at import time, and shared by every comparison.
# ResNet50 / VGG16 use ImageNet weights without the classification head
# (include_top=False) and with average pooling of the final feature map
# (pooling='avg').
resnet_model = ResNet50(weights='imagenet', include_top=False, pooling='avg')
vgg_model = VGG16(weights='imagenet', include_top=False, pooling='avg')
# CLIP ViT-B/32 is loaded on the CPU; the second return value is the image
# preprocessing callable that extract_clip_features applies before encoding.
clip_model, preprocess_clip = clip.load("ViT-B/32", device="cpu")

# Preprocess function
def preprocess_img(img_path, target_size=(224, 224), preprocess_func=resnet_preprocess):
    """Load an image from disk and turn it into a single-item model input batch.

    Args:
        img_path: Path of the image file to load.
        target_size: Size the image is resized to while loading.
        preprocess_func: Model-specific preprocessing applied to the batch.

    Returns:
        The preprocessed array with a leading batch axis of length one.
    """
    began = time.time()
    loaded = keras_image.load_img(img_path, target_size=target_size)
    # Add the batch axis up front; preprocess_func receives the batched array.
    batch = np.expand_dims(keras_image.img_to_array(loaded), axis=0)
    batch = preprocess_func(batch)
    logging.info(f"Image preprocessed in {time.time() - began:.4f} seconds")
    return batch

# Feature extraction function
def extract_features(img_path, model, preprocess_func):
    """Run one image through ``model`` and return its output as a 1-D vector.

    Args:
        img_path: Path of the image file to embed.
        model: Object exposing ``predict(batch)``.
        preprocess_func: Preprocessing callable forwarded to ``preprocess_img``.

    Returns:
        The flattened prediction for the image.
    """
    batch = preprocess_img(img_path, preprocess_func=preprocess_func)
    # Only the forward pass is timed here; preprocessing logs its own timing.
    tick = time.time()
    raw_output = model.predict(batch)
    logging.info(f"Features extracted in {time.time() - tick:.4f} seconds")
    return raw_output.flatten()

# Calculate cosine similarity
def cosine_similarity(vec1, vec2):
    """Return the cosine of the angle between two vectors.

    Args:
        vec1: First vector (array-like accepted by NumPy).
        vec2: Second vector, same length as ``vec1``.

    Returns:
        ``dot(vec1, vec2) / (|vec1| * |vec2|)``. When either vector has zero
        norm the cosine is undefined, and ``0.0`` is returned instead of
        dividing by zero and propagating NaN.
    """
    denominator = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if denominator == 0:
        # A zero vector has no direction; treat it as "no similarity".
        return 0.0
    return np.dot(vec1, vec2) / denominator

# pHash related functions
def phashstr(image, hash_size=8, highfreq_factor=4):
    """Compute a DCT-based perceptual hash of a PIL image as a hex string.

    Args:
        image: PIL image to hash.
        hash_size: Side length of the low-frequency DCT block that is kept.
        highfreq_factor: The image is resized to
            ``hash_size * highfreq_factor`` pixels per side before the DCT.

    Returns:
        Hex string encoding which of the kept DCT coefficients lie above
        their median.
    """
    side = hash_size * highfreq_factor
    grey = image.convert('L').resize((side, side), Image.Resampling.LANCZOS)
    # 2-D DCT built from two 1-D passes: first along axis 0, then along axis 1.
    first_pass = scipy.fftpack.dct(np.asarray(grey), axis=0)
    coefficients = scipy.fftpack.dct(first_pass, axis=1)
    low_block = coefficients[:hash_size, :hash_size]
    above_median = low_block > np.median(low_block)
    return _binary_array_to_hex(above_median.flatten())

def _binary_array_to_hex(arr):
    h = 0
    s = []
    for i, v in enumerate(arr):
        if v:
            h += 2**(i % 8)
        if (i % 8) == 7:
            s.append(hex(h)[2:].rjust(2, '0'))
            h = 0
    return ''.join(s)

def hamming_distance(hash1, hash2):
    """Return the number of differing bits between two hex-encoded hashes.

    The hashes are compared bit by bit rather than character by character:
    one hex digit carries four bits, so counting differing characters would
    under-report the distance by up to a factor of four relative to the
    hash's bit length.

    Args:
        hash1: First hash as a hexadecimal string.
        hash2: Second hash as a hexadecimal string of the same length.

    Returns:
        Count of bit positions at which the two hashes differ.

    Raises:
        ValueError: If the hashes differ in length or are not valid hex.
    """
    if len(hash1) != len(hash2):
        raise ValueError("Hashes must be of the same length")
    if not hash1:
        # Two empty hashes have nothing to differ in.
        return 0
    # XOR leaves a 1 exactly where the two hashes disagree.
    return bin(int(hash1, 16) ^ int(hash2, 16)).count("1")

def hamming_to_similarity(distance, hash_length):
    """Convert a Hamming distance into a percentage similarity.

    Args:
        distance: Number of differing positions.
        hash_length: Total number of positions compared.

    Returns:
        ``100`` when ``distance`` is zero, ``0`` when it equals ``hash_length``.
    """
    fraction_different = distance / hash_length
    return (1 - fraction_different) * 100

# CLIP related functions
def extract_clip_features(image_path, model, preprocess):
    """Encode an image file with a CLIP model and return a 1-D feature vector.

    Args:
        image_path: Path of the image file to encode.
        model: Object exposing ``encode_image(tensor)``.
        preprocess: Callable turning a PIL image into a tensor for ``model``.

    Returns:
        The image embedding as a flattened NumPy array.
    """
    # Open inside a context manager so the file handle is released as soon as
    # the image has been converted to a tensor.
    with Image.open(image_path) as pil_image:
        image = preprocess(pil_image).unsqueeze(0).to("cpu")
    # Inference only: gradients are not needed.
    with torch.no_grad():
        features = model.encode_image(image)
    return features.cpu().numpy().flatten()

# Main function
def compare_images(image1, image2, method):
    """Compare two images with the selected method and return an HTML snippet.

    Args:
        image1: File path of the first image (falsy when nothing was uploaded).
        image2: File path of the second image (falsy when nothing was uploaded).
        method: One of ``'pHash'``, ``'ResNet50'``, ``'VGG16'`` or ``'CLIP'``.

    Returns:
        An HTML ``<span>`` string. On success it holds the similarity as a
        percentage; when an image is missing or the method is not recognised
        it holds a short explanatory message instead.
    """
    if not image1 or not image2:
        return (
            "<span style='font-weight:bold; color:#F44336;'>"
            "Please upload both images before comparing.</span>"
        )

    start_time = time.time()

    if method == 'pHash':
        # Close the image files as soon as both hashes have been computed.
        with Image.open(image1) as img1, Image.open(image2) as img2:
            hash1 = phashstr(img1)
            hash2 = phashstr(img2)
        distance = hamming_distance(hash1, hash2)
        # Each hex character of the hash encodes four bits.
        similarity = hamming_to_similarity(distance, len(hash1) * 4)

    elif method == 'ResNet50':
        features1 = extract_features(image1, resnet_model, resnet_preprocess)
        features2 = extract_features(image2, resnet_model, resnet_preprocess)
        # Cosine similarity is a ratio; scale it so the "%" in the output is true.
        similarity = cosine_similarity(features1, features2) * 100

    elif method == 'VGG16':
        features1 = extract_features(image1, vgg_model, vgg_preprocess)
        features2 = extract_features(image2, vgg_model, vgg_preprocess)
        similarity = cosine_similarity(features1, features2) * 100

    elif method == 'CLIP':
        features1 = extract_clip_features(image1, clip_model, preprocess_clip)
        features2 = extract_clip_features(image2, clip_model, preprocess_clip)
        similarity = cosine_similarity(features1, features2) * 100

    else:
        # No method selected (None) or an unexpected value: there is no score
        # to format, so report that instead of failing on the f-string below.
        return (
            "<span style='font-weight:bold; color:#F44336;'>"
            "Please select a comparison method.</span>"
        )

    logging.info(f"AI based Supporting Documents comparison using {method} completed in {time.time() - start_time:.4f} seconds")

    # Return similarity with HTML formatting for bold and colorful text
    return f"<span style='font-weight:bold; color:#4CAF50;'>Similarity Score: {similarity:.2f}%</span>"

# Gradio interface
# The three input components are passed positionally to
# compare_images(image1, image2, method).
demo = gr.Interface(
    fn=compare_images,
    inputs=[
        # type="filepath": compare_images receives paths, which it hands to
        # Image.open / the feature extractors.
        gr.Image(type="filepath", label="Upload First Image"),
        gr.Image(type="filepath", label="Upload Second Image"),
        # The choice strings are the exact values compare_images branches on.
        gr.Radio(["pHash", "ResNet50", "VGG16", "CLIP"], label="Select Comparison Method")
    ],
    outputs=gr.HTML(label="Similarity"),  # Use HTML for bold and colorful text
    title="AI Based Customs Supporting Documents Comparison",
    description=(
        "Upload two images of supporting documents and select the comparison method.\n"
        "Fraud documents like invoices are used by custom brokers with the same templates. "
        "This tool helps identify similar document templates used in two different consignments.\n"
        "Developed by NCTC."
    ),
    # NOTE(review): each example row lists two image files but the interface
    # has three inputs (no method value) — confirm Gradio accepts this and that
    # the files exist next to this script.
    examples=[
        ["Snipaste_2024-05-31_16-18-31.jpg", "Snipaste_2024-05-31_16-18-52.jpg"],
        ["example1.png", "example2.png"]
    ]
)

# Starts the Gradio app when this module is executed (there is no __main__ guard).
demo.launch()