# Spaces: Runtime error
import gradio as gr | |
import numpy as np | |
import h5py | |
import pandas as pd | |
import pickle | |
from skimage.util import random_noise | |
import time | |
import tensorflow as tf | |
from model import VGGNet | |
def query_normal(query_img: np.ndarray) -> tuple:
    """Rank every indexed image against the query image.

    Relies on the module-level names ``model``, ``features`` and
    ``image_ids`` that the ``__main__`` section of this file assigns.

    Args:
        query_img: Image array handed to ``model.extract_feature``.

    Returns:
        A 3-tuple ``(imlist, top_scores, searched)``:
        the ids (as strings) of up to 10 best-scoring images, their
        scores as a plain list in the same order, and the number of
        images that were scored.
    """
    # NOTE(review): assumes extract_feature returns a 1-D vector so that the
    # dot product below yields one score per row of `features` - confirm.
    query_feature = model.extract_feature(query_img, verbose=0)
    scores = np.dot(query_feature, features.T)

    # argsort is ascending; reverse it so the highest score comes first.
    top = np.argsort(scores)[::-1][:10]

    # Slice first, convert second: only the (at most) 10 selected ids are
    # converted to str instead of the whole id array once per result.
    imlist = list(image_ids[top].astype(str))
    return imlist, scores[top].tolist(), len(scores)
def query_kmeans(query_img: np.ndarray) -> tuple:
    """Rank only the images that share the query image's k-means cluster.

    Relies on the module-level names ``model``, ``kmeans``, ``features``
    and ``image_ids`` that the ``__main__`` section of this file assigns.

    Args:
        query_img: Image array handed to ``model.extract_feature``.

    Returns:
        A 4-tuple ``(imlist, top_scores, searched, cluster)``:
        the ids (as strings) of up to 10 best-scoring images in the
        cluster, their scores as a plain list in the same order, the number
        of images in the cluster that were scored, and the predicted
        cluster id.
    """
    # Extract the feature once; it is reused for both the cluster prediction
    # and the similarity scores.
    # NOTE(review): assumes extract_feature returns a 1-D vector - confirm.
    query_feature = model.extract_feature(query_img, verbose=0)
    cluster = kmeans.predict(query_feature.reshape(1, -1))[0]

    # Row positions (into `features` / `image_ids`) of the cluster members.
    member_idx = np.flatnonzero(kmeans.labels_ == cluster)

    scores = np.dot(query_feature, features[member_idx].T)
    # argsort is ascending; reverse it so the highest score comes first.
    order = np.argsort(scores)[::-1][:10]

    # `order` indexes the cluster-local scores; map it back to global rows
    # before looking up the ids, and convert only the selected ids to str.
    imlist = list(image_ids[member_idx[order]].astype(str))
    return imlist, scores[order].tolist(), len(scores), cluster
def _apply_noise(img, noise, noise_seed, mean, var, amount, salt_vs_pepper):
    """Return ``img`` with the selected ``random_noise`` mode applied.

    Any ``noise`` value that is not one of the handled modes (including
    ``'none'``) returns ``img`` unchanged.
    """
    if noise in ('gaussian', 'speckle'):
        params = {'mean': mean, 'var': var}
    elif noise in ('localvar', 'poisson'):
        params = {}
    elif noise in ('salt', 'pepper'):
        params = {'amount': amount}
    elif noise == 's&p':
        params = {'amount': amount, 'salt_vs_pepper': salt_vs_pepper}
    else:
        return img

    # The image is scaled to [0, 1] before the call and back to 0..255 after.
    noisy = random_noise(img / 255, mode=noise, rng=int(noise_seed), clip=True, **params)
    return np.array(noisy * 255, dtype=np.uint8)


def query(image, mode, noise, noise_seed, mean, var, amount, salt_vs_pepper):
    """Gradio callback: load an image, optionally add noise, and search.

    Args:
        image: Path of the query image, or ``None`` when nothing was given.
        mode: ``'normal'`` (search all images) or ``'kmeans'`` (search only
            the query's cluster).
        noise: Noise mode name; ``'none'`` or any unhandled value applies
            no noise.
        noise_seed: Seed for the noise generator, or ``None``.
        mean: Noise mean, passed on for ``'gaussian'`` and ``'speckle'``.
        var: Noise variance, passed on for ``'gaussian'`` and ``'speckle'``.
        amount: Noise amount, passed on for ``'salt'``, ``'pepper'`` and
            ``'s&p'``.
        salt_vs_pepper: Salt/pepper proportion, passed on for ``'s&p'``.

    Returns:
        A 5-tuple ``(query_img, time_text, gallery, searched, cluster_text)``
        where ``gallery`` is a list of ``(image_id, caption)`` pairs, or five
        ``None`` values when ``image`` or ``noise_seed`` is ``None``.

    Raises:
        ValueError: If ``mode`` is neither ``'normal'`` nor ``'kmeans'``.
    """
    if image is None or noise_seed is None:
        return None, None, None, None, None
    # Fail early with a clear message instead of hitting an unbound local
    # after the image has already been loaded and processed.
    if mode not in ('normal', 'kmeans'):
        raise ValueError(f"mode must be 'normal' or 'kmeans', got {mode!r}")

    loaded = tf.keras.utils.load_img(
        image, target_size=(model.input_shape[0], model.input_shape[1])
    )
    # uint8 keeps the no-noise path consistent with every noise path, which
    # all produce uint8 arrays.
    # NOTE(review): assumes pixel values are in 0..255 here - confirm.
    query_img = tf.keras.utils.img_to_array(loaded).astype(np.uint8)
    query_img = _apply_noise(
        query_img, noise, noise_seed, mean, var, amount, salt_vs_pepper
    )

    # Only the search itself is timed; perf_counter is monotonic.
    start = time.perf_counter()
    if mode == 'normal':
        results, scores, length = query_normal(query_img)
        cluster_text = 'None'
    else:
        results, scores, length, cluster = query_kmeans(query_img)
        cluster_text = f'{cluster}'
    query_time = round((time.perf_counter() - start) * 1000, 2)

    gallery = [
        (result, f'Score: {score}, file: {result}')
        for result, score in zip(results, scores)
    ]
    return query_img, f'Query time: {query_time} ms', gallery, length, cluster_text
if __name__ == '__main__':
    # These names are module globals read by query(), query_normal() and
    # query_kmeans().
    model = VGGNet()

    # Load dataset; the context manager closes the file even if a read fails.
    with h5py.File('features.h5', 'r') as datasets:
        features = datasets['features'][:]
        image_ids = datasets['image_ids'][:]

    # Load kmeans model
    # NOTE: pickle.load can execute arbitrary code; only ship a kmeans.pkl
    # that comes from a trusted source.
    with open('kmeans.pkl', 'rb') as f:
        kmeans = pickle.load(f)

    # Run web app
    iface = gr.Interface(
        title='An image search engine based on VGG16',
        description='dataset: https://www.kaggle.com/datasets/hcfighting/tinyimagenet200. Error may occur when result showing because i\'d deleted some images in the dataset, please try another image to search.',
        fn=query,
        # Order must match the parameter order of query().
        inputs=[
            gr.Image(type='filepath'),
            gr.Radio(['normal', 'kmeans'], value='normal', info='normal: search all images, kmeans: search images in the same cluster'),
            gr.Radio(['none', 'gaussian', 'localvar', 'poisson', 'salt', 'pepper', 's&p', 'speckle'], value='none', info='noise type (default: none)'),
            gr.Number(label='random_seed', value=0, info='random seed for noise (default: 0)'),
            # query() forwards mean/var for 'gaussian' and 'speckle' only.
            gr.Slider(label='mean', maximum=1, minimum=0, value=0, info='mean of the noise, works with gaussian and speckle (default: 0)'),
            gr.Slider(label='var', maximum=1, minimum=0, value=0.01, info='variance of the noise, works with gaussian and speckle (default: 0.01)'),
            gr.Slider(label='amount', maximum=1, minimum=0, value=0.05, info='proportion of image pixels to replace with noise, works with salt, pepper, and s&p (default: 0.05)'),
            gr.Slider(label='salt_vs_pepper', maximum=1, minimum=0, value=0.5, info='proportion of s&p noise (default: 0.5)'),
        ],
        # Order must match the 5-tuple returned by query().
        outputs=[
            gr.Image(label='query image', type='numpy'),
            gr.Label(label='query time'),
            # Layout is passed to the constructor; the chained .style() call
            # is no longer available in current Gradio releases.
            # NOTE(review): confirm the pinned Gradio version accepts
            # columns/rows as constructor arguments.
            gr.Gallery(label='Top 10 similar images', columns=[5], rows=[2]),
            gr.Label(label='count of images for searching'),
            gr.Label(label='cluster id'),
        ],
    )
    iface.launch()