import io
import csv
import sys
import pickle
from collections import Counter

import numpy as np
import gradio as gr
import gdown
import torchvision
from torchvision.datasets import ImageFolder
from PIL import Image

from SimSearch import FaissCosineNeighbors, SearchableTrainingSet
from ExtractEmbedding import QueryToEmbedding
from CHMCorr import chm_classify_and_visualize
from visualization import plot_from_reranker_output

csv.field_size_limit(sys.maxsize)


def concat(x):
    """Concatenate a sequence of arrays along axis 0."""
    return np.concatenate(x, axis=0)


# Retrieval / display parameters. The UI text "N=50, k=20" below mirrors
# N_NEIGHBORS and K_VOTERS; keep them in sync if either changes.
N_NEIGHBORS = 50  # neighbours retrieved from the index; all go to CHM as support
K_VOTERS = 20  # nearest neighbours whose labels are majority-voted
TOP_LABELS = 5  # number of (label, vote fraction) pairs shown in the UI
N_GALLERY = 5  # winning-class neighbours passed to CHM as the kNN gallery
# Pixels trimmed from the rendered figure: (left, top, right, bottom).
CROP_MARGINS = (540, 40, 492, 100)


# Embeddings
gdown.cached_download(
    url="https://static.taesiri.com/chm-corr/embeddings.pickle",
    path="./embeddings.pickle",
    quiet=False,
    md5="002b2a7f5c80d910b9cc740c2265f058",
)

# Labels. The explicit output path makes the file land exactly where it is
# read below, instead of depending on the file name Google Drive reports.
# NOTE(review): unlike the other artifacts this download has no md5 check,
# yet it is unpickled below — consider pinning its checksum.
gdown.download(id="1SDtq6ap7LPPpYfLbAxaMGGmj0EAV_m_e", output="./labels.pickle")

# CUB training set
gdown.cached_download(
    url="https://static.taesiri.com/chm-corr/CUB_train.zip",
    path="./CUB_train.zip",
    quiet=False,
    md5="1bd99e73b2fea8e4c2ebcb0e7722f1b1",
)

# Extract the training set
torchvision.datasets.utils.extract_archive(
    from_path="CUB_train.zip",
    to_path="data/",
    remove_finished=False,
)

# CHM Weights
gdown.cached_download(
    url="https://static.taesiri.com/chm-corr/pas_psi.pt",
    path="pas_psi.pt",
    quiet=False,
    md5="6b7b4d7bad7f89600fac340d6aa7708b",
)

# Load the precomputed training-set embeddings and labels and index them.
# pickle can execute arbitrary code: only load files from trusted sources.
with open("./embeddings.pickle", "rb") as f:
    Xtrain = pickle.load(f)
# FIXME: re-run the code to get the embeddings in the right format
with open("./labels.pickle", "rb") as f:
    ytrain = pickle.load(f)

searcher = SearchableTrainingSet(Xtrain, ytrain)
searcher.build_index()

# Map class index -> human-readable class-folder name ("." replaced by " ").
# class_to_idx is keyed by the class folder name, so this does not depend on
# the OS path separator or on how deep the image files are nested.
training_folder = ImageFolder(root="./data/train/")
id_to_bird_name = {
    idx: name.replace(".", " ") for name, idx in training_folder.class_to_idx.items()
}


def _knn_vote(labels, indices):
    """Majority-vote over the first K_VOTERS neighbours of the first query.

    Returns:
        (ranking, winner, winner_indices): the TOP_LABELS most common
        (label, count) pairs, the most common label, and the index entries
        of the voting neighbours that carry that label, in retrieval order.
    """
    voter_labels = labels[0][:K_VOTERS]
    voter_indices = indices[0][:K_VOTERS]
    ranking = Counter(voter_labels).most_common(TOP_LABELS)
    winner = ranking[0][0]
    winner_indices = [
        idx for lbl, idx in zip(voter_labels, voter_indices) if lbl == winner
    ]
    return ranking, winner, winner_indices


def _render_cropped(fig):
    """Save *fig* as JPEG in memory and trim CROP_MARGINS from its edges."""
    buf = io.BytesIO()
    fig.savefig(buf, format="jpg")
    buf.seek(0)
    image = Image.open(buf)
    width, height = image.size
    left, top, right, bottom = CROP_MARGINS
    return image.crop((left, top, width - right, height - bottom))


def search(query_image, draw_arcs, searcher=searcher):
    """Classify *query_image* with a kNN vote and visualize CHM-Corr output.

    Args:
        query_image: the image handed over by the Gradio input component
            (configured with type="filepath" below).
        draw_arcs: forwarded unchanged to plot_from_reranker_output.
        searcher: index to query; defaults to the module-level training index.

    Returns:
        (viz_image, label_scores): the cropped PIL image of the CHM-Corr
        plot, and a dict {class name: vote fraction among K_VOTERS} for the
        TOP_LABELS most-voted kNN labels.
    """
    query_embedding = QueryToEmbedding(query_image)
    _, indices, labels = searcher.search(query_embedding, k=N_NEIGHBORS)

    ranking, top1_label, top_indices = _knn_vote(labels, indices)
    gallery_images = [
        training_folder.imgs[int(i)][0] for i in top_indices[:N_GALLERY]
    ]
    label_scores = {
        id_to_bird_name[label]: count / K_VOTERS for label, count in ranking
    }

    # CHM prediction: pass the kNN result plus all N_NEIGHBORS retrieved
    # training images (paths and class indices) as support.
    knn_results = (top1_label, ranking[0][1], gallery_images)
    support_files = [training_folder.imgs[int(i)][0] for i in indices[0]]
    support_labels = [training_folder.imgs[int(i)][1] for i in indices[0]]
    support = [support_files, support_labels]
    chm_output = chm_classify_and_visualize(
        query_image, knn_results, support, training_folder
    )

    fig = plot_from_reranker_output(chm_output, draw_arcs=draw_arcs)
    # NOTE(review): fig is never closed. If plot_from_reranker_output creates
    # it through pyplot, figures accumulate across requests — consider
    # matplotlib.pyplot.close(fig) after rendering.
    viz_image = _render_cropped(fig)
    return viz_image, label_scores


blocks = gr.Blocks()

with blocks:
    gr.Markdown(""" # CHM-Corr DEMO""")
    gr.Markdown(""" ### Parameters: N=50, k=20 - Using ResNet50 features""")

    input_image = gr.Image(type="filepath")
    with gr.Column():
        arcs_checkbox = gr.Checkbox(label="Draw Arcs")
        run_btn = gr.Button("Classify")

    gr.Markdown(""" ### CHM-Corr Output """)
    viz_plot = gr.Image(type="pil")
    gr.Markdown(""" ### kNN Predicted Labels """)
    predicted_labels = gr.Label(label="kNN Prediction")

    run_btn.click(
        search,
        inputs=[input_image, arcs_checkbox],
        outputs=[viz_plot, predicted_labels],
    )

if __name__ == "__main__":
    blocks.launch(
        debug=True,
        enable_queue=True,
    )