"""Gradio web app: image search over precomputed VGG features.

The query image is optionally corrupted with noise, embedded with ``VGGNet``
and ranked against the stored feature matrix either globally ('normal') or
only within the k-means cluster the query falls into ('kmeans').

The module-level names ``model``, ``features``, ``image_ids`` and ``kmeans``
are created in the ``__main__`` section and read by the query functions.
"""
import pickle
import time

import gradio as gr
import h5py
import numpy as np
import pandas as pd
import tensorflow as tf
from skimage.util import random_noise

from model import VGGNet

# Number of results returned to the gallery.
TOP_K = 10


def _rank_top_k(scores, candidate_idx=None):
    """Return the paths and scores of the ``TOP_K`` best-scoring images.

    Args:
        scores: 1-D array of similarity scores, one per candidate.
        candidate_idx: optional array mapping positions in ``scores`` to rows
            of the global ``image_ids``; ``None`` means positions are already
            global row numbers.

    Returns:
        ``(paths, top_scores)`` where ``paths`` is a list of image ids with
        backslashes normalised to forward slashes and ``top_scores`` is a
        plain list of the matching scores, best first.
    """
    best = np.argsort(scores)[::-1][:TOP_K]
    rows = best if candidate_idx is None else candidate_idx[best]
    # Convert only the selected ids instead of the whole id array.
    paths = [name.replace('\\', '/') for name in image_ids[rows].astype(str)]
    return paths, scores[best].tolist()


def query_normal(query_img: np.ndarray):
    """Rank every indexed image against ``query_img``.

    Returns:
        ``(paths, scores, searched)``: top image paths, their dot-product
        scores, and the number of images that were compared.
    """
    query_feature = model.extract_feature(query_img, verbose=0)
    # presumably `features` is (n_images, dim) and `query_feature` is (dim,)
    # so this yields one score per image — TODO confirm against VGGNet.
    scores = np.dot(query_feature, features.T)
    paths, top_scores = _rank_top_k(scores)
    return paths, top_scores, len(scores)


def query_kmeans(query_img: np.ndarray):
    """Rank only the images in the query's k-means cluster.

    Returns:
        ``(paths, scores, searched, cluster)``: top image paths, their
        dot-product scores, the number of images in the cluster, and the
        predicted cluster id.
    """
    # Extract the feature once; it is used both for cluster prediction and
    # for scoring.
    query_feature = model.extract_feature(query_img, verbose=0)
    cluster = kmeans.predict(query_feature.reshape(1, -1))[0]

    # Row numbers (into `features` / `image_ids`) of the cluster members.
    members = np.flatnonzero(kmeans.labels_ == cluster)
    scores = np.dot(query_feature, features[members].T)
    paths, top_scores = _rank_top_k(scores, members)
    return paths, top_scores, len(scores), cluster


def _apply_noise(img, noise, seed, mean, var, amount, salt_vs_pepper):
    """Return ``img`` corrupted with the requested skimage noise mode.

    ``img`` holds pixel values in 0..255. For a recognised noise mode the
    result is a ``uint8`` array; for ``'none'`` (or any unrecognised value)
    ``img`` is returned untouched.
    """
    # Each mode accepts a different subset of the tuning parameters.
    if noise in ('gaussian', 'speckle'):
        extra = {'mean': mean, 'var': var}
    elif noise in ('localvar', 'poisson'):
        extra = {}
    elif noise in ('salt', 'pepper'):
        extra = {'amount': amount}
    elif noise == 's&p':
        extra = {'amount': amount, 'salt_vs_pepper': salt_vs_pepper}
    else:
        return img

    # random_noise works on floats in [0, 1]; scale in and back out.
    noisy = random_noise(img / 255, mode=noise, rng=int(seed), clip=True, **extra)
    return np.array(noisy * 255, dtype=np.uint8)


def query(image, mode, noise, noise_seed, mean, var, amount, salt_vs_pepper):
    """Gradio callback: load, optionally corrupt, and search for an image.

    Args:
        image: file path of the uploaded image, or ``None``.
        mode: ``'normal'`` (search everything) or ``'kmeans'`` (search the
            query's cluster only).
        noise: noise mode name, ``'none'`` to disable.
        noise_seed: seed for the noise generator, or ``None``.
        mean, var, amount, salt_vs_pepper: noise parameters; each is used
            only by the modes that accept it.

    Returns:
        A 5-tuple matching the interface outputs: the (possibly noisy) query
        image, a query-time message, a list of ``(path, caption)`` gallery
        items, the number of images searched, and the cluster id as a string
        (``'None'`` in normal mode). All five are ``None`` when no image or
        no seed was supplied.

    Raises:
        ValueError: if ``mode`` is neither ``'normal'`` nor ``'kmeans'``.
    """
    if image is None or noise_seed is None:
        return None, None, None, None, None

    target_size = (model.input_shape[0], model.input_shape[1])
    query_img = tf.keras.utils.load_img(image, target_size=target_size)
    query_img = tf.keras.utils.img_to_array(query_img).astype(int)
    query_img = _apply_noise(
        query_img, noise, noise_seed, mean, var, amount, salt_vs_pepper
    )

    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by system clock adjustments.
    start = time.perf_counter()
    cluster = None
    if mode == 'normal':
        results, scores, length = query_normal(query_img)
    elif mode == 'kmeans':
        results, scores, length, cluster = query_kmeans(query_img)
    else:
        raise ValueError(f'unknown search mode: {mode!r}')
    query_time = round((time.perf_counter() - start) * 1000, 2)

    gallery = [
        (result, f'Score: {score}, file: {result}')
        for result, score in zip(results, scores)
    ]
    cluster_label = f'{cluster}' if mode == 'kmeans' else 'None'
    return query_img, f'Query time: {query_time} ms', gallery, length, cluster_label


if __name__ == '__main__':
    model = VGGNet()

    # Load dataset; the context manager closes the file even if a read fails.
    with h5py.File('features.h5', 'r') as datasets:
        features = datasets['features'][:]
        image_ids = datasets['image_ids'][:]

    # Load kmeans model.
    # NOTE: pickle executes arbitrary code on load; only use a trusted file.
    with open('kmeans.pkl', 'rb') as f:
        kmeans = pickle.load(f)

    # Run web app
    # NOTE(review): `Gallery.style()` is deprecated in later Gradio releases —
    # confirm the pinned Gradio version before upgrading.
    iface = gr.Interface(
        title='An image search engine based on VGG16',
        description=(
            'dataset: https://www.kaggle.com/datasets/hcfighting/tinyimagenet200. '
            'Error may occur when result showing because i\'d deleted some images '
            'in the dataset, please try another image to search.'
        ),
        fn=query,
        inputs=[
            gr.Image(type='filepath'),
            gr.Radio(
                ['normal', 'kmeans'],
                value='normal',
                info='normal: search all images, kmeans: search images in the same cluster',
            ),
            gr.Radio(
                ['none', 'gaussian', 'localvar', 'poisson', 'salt', 'pepper', 's&p', 'speckle'],
                value='none',
                info='noise type (default: none)',
            ),
            gr.Number(
                label='random_seed',
                value=0,
                info='random seed for noise (default: 0)',
            ),
            gr.Slider(
                label='mean',
                maximum=1,
                minimum=0,
                value=0,
                info='mean of the noise, works with gaussian and speckle (default: 0)',
            ),
            gr.Slider(
                label='var',
                maximum=1,
                minimum=0,
                value=0.01,
                info='variance of the noise, works with gaussian and speckle (default: 0.01)',
            ),
            gr.Slider(
                label='amount',
                maximum=1,
                minimum=0,
                value=0.05,
                info='proportion of image pixels to replace with noise, works with salt, pepper, and s&p (default: 0.05)',
            ),
            gr.Slider(
                label='salt_vs_pepper',
                maximum=1,
                minimum=0,
                value=0.5,
                info='proportion of s&p noise (default: 0.5)',
            ),
        ],
        outputs=[
            gr.Image(label='query image', type='numpy'),
            gr.Label(label='query time'),
            gr.Gallery(label='Top 10 similar images').style(columns=[5], rows=[2]),
            gr.Label(label='count of images for searching'),
            gr.Label(label='cluster id'),
        ],
    )
    iface.launch()