from datasets import load_dataset
from torchvision import transforms
import torch
from timm import create_model
from omegaconf import OmegaConf
import faiss
import pickle
import gradio as gr
import os
import joblib
import torch.nn as nn
from typing import Dict, Iterable, Callable
from torch import Tensor
import torchvision
from PIL import Image


# Supported values of ``PARAMETERS.metric`` and their FAISS metric constants.
_FAISS_METRICS = {
    'L2': faiss.METRIC_L2,
    'IP': faiss.METRIC_INNER_PRODUCT,
}

# Number of matches shown in the UI (one textbox + one image per match).
_NUM_MATCHES = 3


def get_model(args, arch, load_from, arch_path):
    """Build a backbone, optionally load a checkpoint, and put it in eval mode.

    Args:
        args: OmegaConf config; ``args.PARAMETERS.device`` is the target device.
        arch: Architecture name (a timm model name, or ``'resnet50'`` for
            torchvision).
        load_from: ``'timm'`` or ``'torchvision'``.
        arch_path: Checkpoint path whose ``'state_dict'`` entry is loaded
            strictly into the model; an empty string or ``None`` skips loading.

    Returns:
        The model, moved to ``args.PARAMETERS.device`` and set to eval mode.

    Raises:
        ValueError: If ``load_from`` / ``arch`` is not a supported combination.
    """
    if load_from == 'timm':
        model = create_model(arch, pretrained=True)
        print("Load model timm")
    elif load_from == 'torchvision':
        if arch != 'resnet50':
            raise ValueError(f"Unsupported torchvision architecture: {arch!r}")
        model = torchvision.models.resnet50(pretrained=False)
    else:
        raise ValueError(f"Unsupported load_from: {load_from!r}")

    if arch_path:
        print("Loading pretrained Model")
        # NOTE(review): torch.load unpickles the file -- only point this at
        # trusted checkpoints.
        state_dict = torch.load(arch_path, map_location='cpu')['state_dict']
        model.load_state_dict(state_dict, strict=True)

    # Move to the configured device for *every* branch; previously only the
    # timm model was moved, so torchvision models stayed on CPU while the
    # query tensor was sent to PARAMETERS.device.
    model = model.to(args.PARAMETERS.device)
    model.eval()
    return model


def get_transform(args):
    """Return the preprocessing pipeline applied to query images.

    Resizes to ``img_resize`` x ``img_resize``, center-crops to
    ``img_crop`` x ``img_crop`` and converts to a float tensor via
    ``ToTensor``. No mean/std normalization is applied here; presumably this
    must match how the indexed embeddings were computed -- TODO confirm.
    """
    resize = args.PARAMETERS.img_resize
    crop = args.PARAMETERS.img_crop
    return transforms.Compose([
        transforms.Resize([resize, resize]),
        transforms.CenterCrop([crop, crop]),
        transforms.ToTensor(),
    ])


class FeatureExtractor(nn.Module):
    """Wrap a model and capture the outputs of selected named submodules.

    A forward hook is registered on each requested layer. Calling the
    extractor runs the wrapped model and returns ``{layer_name: output}``.
    """

    def __init__(self, model: nn.Module, layers: Iterable[str]):
        super().__init__()
        self.model = model
        # Materialize once: a one-shot iterator would be exhausted by the dict
        # comprehension below and then no hooks would get registered.
        self.layers = list(layers)
        self._features = {layer: torch.empty(0) for layer in self.layers}
        # Build the name -> module lookup once instead of once per layer.
        modules = dict(self.model.named_modules())
        for layer_id in self.layers:
            if layer_id not in modules:
                raise KeyError(f"Layer {layer_id!r} not found in model")
            modules[layer_id].register_forward_hook(self.save_outputs_hook(layer_id))

    def save_outputs_hook(self, layer_id: str) -> Callable:
        """Return a forward hook that stores the module output under ``layer_id``."""
        def fn(_, __, output):
            self._features[layer_id] = output
        return fn

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Run the wrapped model on ``x`` and return the captured layer outputs."""
        _ = self.model(x)
        # Return a copy so callers never alias the hook-updated internal dict.
        return dict(self._features)


def _load_dataset(args):
    """Load the ``train`` split and attach a FAISS index to each embedding column.

    Raises:
        ValueError: If ``args.PARAMETERS.metric`` is not a supported metric.
    """
    metric = args.PARAMETERS.metric
    try:
        faiss_metric = _FAISS_METRICS[metric]
    except KeyError:
        raise ValueError(
            f"Unsupported metric {metric!r}; expected one of {sorted(_FAISS_METRICS)}"
        ) from None
    dataset = load_dataset(args.PARAMETERS.dataset, split='train')
    for embedding_col in (args.ROBUST.embedding_col, args.NONROBUST.embedding_col):
        dataset = dataset.add_faiss_index(column=embedding_col, metric_type=faiss_metric)
    return dataset


args = OmegaConf.load("configs/resnet.yaml")
wiki_dataset = _load_dataset(args)
TRANSFORMS = get_transform(args)
robust_model = get_model(args, args.ROBUST.arch, args.ROBUST.load_from, args.ROBUST.arch_path)
non_robust_model = get_model(args, args.NONROBUST.arch, args.NONROBUST.load_from, args.NONROBUST.arch_path)
fe_robust_model = FeatureExtractor(robust_model, layers=[args.ROBUST.layer])
fe_nonrobust_model = FeatureExtractor(non_robust_model, layers=[args.NONROBUST.layer])


# +
def retrieval_fn(image, radio, k=_NUM_MATCHES):
    """Embed ``image`` with the selected model and return its ``k`` nearest neighbours.

    Args:
        image: A numpy array (what ``gr.Image`` passes by default) or a PIL image.
        radio: ``'robust'`` or ``'standard'``; selects the model, hooked layer
            and FAISS index.
        k: Number of neighbours to retrieve.

    Returns:
        ``(scores, retrieved_examples)`` from ``Dataset.get_nearest_examples``.

    Raises:
        ValueError: If no image is given or ``radio`` is not a known option.
    """
    if image is None:
        raise ValueError("Please upload an image first.")
    if radio == 'robust':
        extractor, cfg = fe_robust_model, args.ROBUST
    elif radio == 'standard':
        extractor, cfg = fe_nonrobust_model, args.NONROBUST
    else:
        raise ValueError(f"Unknown model choice: {radio!r}")

    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Force 3 channels: ToTensor would otherwise produce 4 (RGBA) or
    # 1 (grayscale) channels, which the backbone cannot consume.
    image = image.convert('RGB')
    batch = TRANSFORMS(image).unsqueeze(0).to(args.PARAMETERS.device)

    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        features = extractor(batch)[cfg.layer]
    emb = features.reshape(1, -1).cpu().numpy()
    scores, retrieved_examples = wiki_dataset.get_nearest_examples(
        index_name=cfg.embedding_col, query=emb, k=k)
    return scores, retrieved_examples


def gradio_fn(image, radio):
    """Format retrieval results for the six output components of the UI.

    Returns a flat list ``[description_1, image_1, description_2, image_2, ...]``
    padded with ``None`` so it always holds ``_NUM_MATCHES`` pairs, even when
    fewer matches are found.
    """
    _, retrieved_examples = retrieval_fn(image, radio, k=_NUM_MATCHES)
    outputs = []
    for description, match_image in zip(retrieved_examples['description'],
                                         retrieved_examples['image']):
        outputs.extend([description, match_image])
    outputs.extend([None] * (2 * _NUM_MATCHES - len(outputs)))
    return outputs
# -


if __name__ == '__main__':
    with gr.Blocks() as demo:
        gr.Markdown("# Robust vs Standard Image Retrieval")
        with gr.Tabs():
            with gr.TabItem("Upload your Image"):
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            image_input = gr.Image(label="Input Image")
                        with gr.Row():
                            radio_button = gr.Radio(["robust", "standard"], value="robust", label="OD Model")
                        with gr.Row():
                            calculate_button = gr.Button("Compute")
                    with gr.Column():
                        textbox1 = gr.Textbox(label="Artist / Title / Style / Genre / Date")
                        output_image1 = gr.Image(label="1st Best match")
                        textbox2 = gr.Textbox(label="Artist / Title / Style / Genre / Date")
                        output_image2 = gr.Image(label="2nd Best match")
                        textbox3 = gr.Textbox(label="Artist / Title / Style / Genre / Date")
                        output_image3 = gr.Image(label="3rd Best match")
        calculate_button.click(
            fn=gradio_fn,
            inputs=[image_input, radio_button],
            outputs=[textbox1, output_image1, textbox2, output_image2, textbox3, output_image3],
        )
    demo.launch(share=False, debug=True)