NCTCMumbai committed
Commit 16348d6
1 Parent(s): a4fb644

Update app.py

Files changed (1)
  1. app.py +125 -125
app.py CHANGED
@@ -1,125 +1,125 @@
- import gradio as gr
- import numpy as np
- import tensorflow as tf
- import logging
- from PIL import Image
- from tensorflow.keras.preprocessing import image as keras_image
- from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input as resnet_preprocess
- from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input as vgg_preprocess
- import scipy.fftpack
- import time
- import clip
- import torch
-
- # Set up logging
- logging.basicConfig(level=logging.INFO)
-
- # Load models
- resnet_model = ResNet50(weights='imagenet', include_top=False, pooling='avg')
- vgg_model = VGG16(weights='imagenet', include_top=False, pooling='avg')
- clip_model, preprocess_clip = clip.load("ViT-B/32", device="cpu")
-
- # Preprocess function
- def preprocess_img(img_path, target_size=(224, 224), preprocess_func=resnet_preprocess):
-     start_time = time.time()
-     img = keras_image.load_img(img_path, target_size=target_size)
-     img_array = keras_image.img_to_array(img)
-     img_array = np.expand_dims(img_array, axis=0)
-     img_array = preprocess_func(img_array)
-     logging.info(f"Image preprocessed in {time.time() - start_time:.4f} seconds")
-     return img_array
-
- # Feature extraction function
- def extract_features(img_path, model, preprocess_func):
-     img_array = preprocess_img(img_path, preprocess_func=preprocess_func)
-     start_time = time.time()
-     features = model.predict(img_array)
-     logging.info(f"Features extracted in {time.time() - start_time:.4f} seconds")
-     return features.flatten()
-
- # Calculate cosine similarity
- def cosine_similarity(vec1, vec2):
-     return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
-
- # pHash related functions
- def phashstr(image, hash_size=8, highfreq_factor=4):
-     img_size = hash_size * highfreq_factor
-     image = image.convert('L').resize((img_size, img_size), Image.Resampling.LANCZOS)
-     pixels = np.asarray(image)
-     dct = scipy.fftpack.dct(scipy.fftpack.dct(pixels, axis=0), axis=1)
-     dctlowfreq = dct[:hash_size, :hash_size]
-     med = np.median(dctlowfreq)
-     diff = dctlowfreq > med
-     return _binary_array_to_hex(diff.flatten())
-
- def _binary_array_to_hex(arr):
-     h = 0
-     s = []
-     for i, v in enumerate(arr):
-         if v:
-             h += 2**(i % 8)
-         if (i % 8) == 7:
-             s.append(hex(h)[2:].rjust(2, '0'))
-             h = 0
-     return ''.join(s)
-
- def hamming_distance(hash1, hash2):
-     if len(hash1) != len(hash2):
-         raise ValueError("Hashes must be of the same length")
-     return sum(c1 != c2 for c1, c2 in zip(hash1, hash2))
-
- def hamming_to_similarity(distance, hash_length):
-     return (1 - distance / hash_length) * 100
-
- # CLIP related functions
- def extract_clip_features(image_path, model, preprocess):
-     image = preprocess(Image.open(image_path)).unsqueeze(0).to("cpu")
-     with torch.no_grad():
-         features = model.encode_image(image)
-     return features.cpu().numpy().flatten()
-
- # Main function
- def compare_images(image1, image2, method):
-     similarity = None
-     start_time = time.time()
-     if method == 'pHash':
-         img1 = Image.open(image1)
-         img2 = Image.open(image2)
-         hash1 = phashstr(img1)
-         hash2 = phashstr(img2)
-         distance = hamming_distance(hash1, hash2)
-         similarity = hamming_to_similarity(distance, len(hash1) * 4)
-     elif method == 'ResNet50':
-         features1 = extract_features(image1, resnet_model, resnet_preprocess)
-         features2 = extract_features(image2, resnet_model, resnet_preprocess)
-         similarity = cosine_similarity(features1, features2)
-     elif method == 'VGG16':
-         features1 = extract_features(image1, vgg_model, vgg_preprocess)
-         features2 = extract_features(image2, vgg_model, vgg_preprocess)
-         similarity = cosine_similarity(features1, features2)
-     elif method == 'CLIP':
-         features1 = extract_clip_features(image1, clip_model, preprocess_clip)
-         features2 = extract_clip_features(image2, clip_model, preprocess_clip)
-         similarity = cosine_similarity(features1, features2)
-
-     logging.info(f"Image comparison using {method} completed in {time.time() - start_time:.4f} seconds")
-     return similarity
-
- # Gradio interface
- demo = gr.Interface(
-     fn=compare_images,
-     inputs=[
-         gr.Image(type="filepath", label="Upload First Image"),
-         gr.Image(type="filepath", label="Upload Second Image"),
-         gr.Radio(["pHash", "ResNet50", "VGG16", "CLIP"], label="Select Comparison Method")
-     ],
-     outputs=gr.Textbox(label="Similarity"),
-     title="Image Similarity Comparison",
-     description="Upload two images and select the comparison method.",
-     examples=[
-         ["Snipaste_2024-05-31_16-18-31.jpg", "Snipaste_2024-05-31_16-18-52.jpg"],
-         ["example1.png", "example2.png"]
-     ]
- )
-
- demo.launch()
+ import gradio as gr
+ import numpy as np
+ import tensorflow as tf
+ import logging
+ from PIL import Image
+ from tensorflow.keras.preprocessing import image as keras_image
+ from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input as resnet_preprocess
+ from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input as vgg_preprocess
+ import scipy.fftpack
+ import time
+ import clip
+ import torch
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+
+ # Load models
+ resnet_model = ResNet50(weights='imagenet', include_top=False, pooling='avg')
+ vgg_model = VGG16(weights='imagenet', include_top=False, pooling='avg')
+ clip_model, preprocess_clip = clip.load("ViT-B/32", device="cpu")
+
+ # Preprocess function
+ def preprocess_img(img_path, target_size=(224, 224), preprocess_func=resnet_preprocess):
+     start_time = time.time()
+     img = keras_image.load_img(img_path, target_size=target_size)
+     img_array = keras_image.img_to_array(img)
+     img_array = np.expand_dims(img_array, axis=0)
+     img_array = preprocess_func(img_array)
+     logging.info(f"Image preprocessed in {time.time() - start_time:.4f} seconds")
+     return img_array
+
+ # Feature extraction function
+ def extract_features(img_path, model, preprocess_func):
+     img_array = preprocess_img(img_path, preprocess_func=preprocess_func)
+     start_time = time.time()
+     features = model.predict(img_array)
+     logging.info(f"Features extracted in {time.time() - start_time:.4f} seconds")
+     return features.flatten()
+
+ # Calculate cosine similarity
+ def cosine_similarity(vec1, vec2):
+     return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
+
+ # pHash related functions
+ def phashstr(image, hash_size=8, highfreq_factor=4):
+     img_size = hash_size * highfreq_factor
+     image = image.convert('L').resize((img_size, img_size), Image.Resampling.LANCZOS)
+     pixels = np.asarray(image)
+     dct = scipy.fftpack.dct(scipy.fftpack.dct(pixels, axis=0), axis=1)
+     dctlowfreq = dct[:hash_size, :hash_size]
+     med = np.median(dctlowfreq)
+     diff = dctlowfreq > med
+     return _binary_array_to_hex(diff.flatten())
+
+ def _binary_array_to_hex(arr):
+     h = 0
+     s = []
+     for i, v in enumerate(arr):
+         if v:
+             h += 2**(i % 8)
+         if (i % 8) == 7:
+             s.append(hex(h)[2:].rjust(2, '0'))
+             h = 0
+     return ''.join(s)
+
+ def hamming_distance(hash1, hash2):
+     if len(hash1) != len(hash2):
+         raise ValueError("Hashes must be of the same length")
+     return sum(c1 != c2 for c1, c2 in zip(hash1, hash2))
+
+ def hamming_to_similarity(distance, hash_length):
+     return (1 - distance / hash_length) * 100
+
+ # CLIP related functions
+ def extract_clip_features(image_path, model, preprocess):
+     image = preprocess(Image.open(image_path)).unsqueeze(0).to("cpu")
+     with torch.no_grad():
+         features = model.encode_image(image)
+     return features.cpu().numpy().flatten()
+
+ # Main function
+ def compare_images(image1, image2, method):
+     similarity = None
+     start_time = time.time()
+     if method == 'pHash':
+         img1 = Image.open(image1)
+         img2 = Image.open(image2)
+         hash1 = phashstr(img1)
+         hash2 = phashstr(img2)
+         distance = hamming_distance(hash1, hash2)
+         similarity = hamming_to_similarity(distance, len(hash1) * 4)
+     elif method == 'ResNet50':
+         features1 = extract_features(image1, resnet_model, resnet_preprocess)
+         features2 = extract_features(image2, resnet_model, resnet_preprocess)
+         similarity = cosine_similarity(features1, features2)
+     elif method == 'VGG16':
+         features1 = extract_features(image1, vgg_model, vgg_preprocess)
+         features2 = extract_features(image2, vgg_model, vgg_preprocess)
+         similarity = cosine_similarity(features1, features2)
+     elif method == 'CLIP':
+         features1 = extract_clip_features(image1, clip_model, preprocess_clip)
+         features2 = extract_clip_features(image2, clip_model, preprocess_clip)
+         similarity = cosine_similarity(features1, features2)
+
+     logging.info(f"AI-based Supporting Documents comparison using {method} completed in {time.time() - start_time:.4f} seconds")
+     return similarity
+
+ # Gradio interface
+ demo = gr.Interface(
+     fn=compare_images,
+     inputs=[
+         gr.Image(type="filepath", label="Upload First Image"),
+         gr.Image(type="filepath", label="Upload Second Image"),
+         gr.Radio(["pHash", "ResNet50", "VGG16", "CLIP"], label="Select Comparison Method")
+     ],
+     outputs=gr.Textbox(label="Similarity"),
+     title="AI-based Customs Supporting Documents Comparison",
+     description="Upload two images of supporting documents and select the comparison method.",
+     examples=[
+         ["Snipaste_2024-05-31_16-18-31.jpg", "Snipaste_2024-05-31_16-18-52.jpg"],
+         ["example1.png", "example2.png"]
+     ]
+ )
+
+ demo.launch()
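
The only functional change in this commit is the wording of the final log message and of the Gradio title and description; the comparison logic itself is untouched. For reference, a minimal sketch of how the committed compare_images function could be driven outside the Gradio UI; this assumes app.py is on the import path, and the two image paths are hypothetical placeholders:

# Usage sketch (hypothetical paths; assumes app.py is importable).
# Caveat: as committed, importing app.py loads all three models and
# calls demo.launch() at import time; guarding the launch with
# `if __name__ == "__main__":` would make this import side-effect free.
from app import compare_images

# pHash returns a 0-100 percentage derived from the Hamming distance
# over the 16-hex-char hash; ResNet50, VGG16 and CLIP return cosine
# similarity in [-1, 1].
for method in ["pHash", "ResNet50", "VGG16", "CLIP"]:
    score = compare_images("doc_a.png", "doc_b.png", method)
    print(f"{method}: {score}")

One thing possibly worth flagging for a future commit: hamming_distance counts differing hex characters (at most 16), while hamming_to_similarity is called with a bit length of len(hash1) * 4 = 64, so the pHash score is compressed into the 75-100 range. Comparing bit strings, or dividing by len(hash1) instead, would restore the full 0-100 scale.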