# Author: will33am
# Commit 757440e: "Update app.py"
from datasets import load_dataset
from torchvision import transforms
import torch
from timm import create_model
from omegaconf import OmegaConf
import faiss
import pickle
import gradio as gr
import os
import joblib
import torch.nn as nn
from typing import Dict, Iterable, Callable
from torch import Tensor
import torchvision
from PIL import Image
def get_model(args, arch, load_from, arch_path):
    """Build a backbone, move it to ``args.PARAMETERS.device`` and set eval mode.

    Args:
        args: Config object; only ``args.PARAMETERS.device`` is read here.
        arch: Architecture name. For ``load_from='timm'`` it is passed to
            ``timm.create_model``; for ``load_from='torchvision'`` only
            ``'resnet50'`` is supported.
        load_from: ``'timm'`` (model created with ``pretrained=True``) or
            ``'torchvision'`` (created with ``pretrained=False`` and optionally
            filled from ``arch_path``).
        arch_path: Checkpoint path whose ``'state_dict'`` entry is loaded into
            the torchvision model. Empty string or ``None`` skips loading.
            Not used for the timm branch.

    Returns:
        The model on the configured device, in eval mode.

    Raises:
        ValueError: If ``load_from`` is unknown, or ``arch`` is not supported
            for torchvision (previously an ``UnboundLocalError``).
    """
    device = args.PARAMETERS.device
    if load_from == 'timm':
        model = create_model(arch, pretrained=True)
        print("Load model timm")
    elif load_from == 'torchvision':
        if arch != 'resnet50':
            raise ValueError(f"Unsupported torchvision architecture: {arch!r}")
        model = torchvision.models.resnet50(pretrained=False)
        # Truthiness check also tolerates arch_path being None in the config.
        if arch_path:
            print("Loading pretrained Model")
            # Load onto CPU first so GPU-saved checkpoints also load on CPU-only hosts.
            checkpoint = torch.load(arch_path, map_location='cpu')
            model.load_state_dict(checkpoint['state_dict'], strict=True)
    else:
        raise ValueError(
            f"Unsupported load_from: {load_from!r} (expected 'timm' or 'torchvision')"
        )
    # Query tensors are moved to this same device at inference time, so every
    # model (not only the timm one) must live there too.
    model = model.to(device)
    model.eval()
    return model
def get_transform(args):
    """Return the preprocessing pipeline applied to query images.

    Steps: resize to a square of side ``args.PARAMETERS.img_resize``, center-crop
    to a square of side ``args.PARAMETERS.img_crop``, then ``ToTensor``.
    No normalisation step is included.
    """
    params = args.PARAMETERS
    resize_side = params.img_resize
    crop_side = params.img_crop
    steps = [
        transforms.Resize([resize_side, resize_side]),
        transforms.CenterCrop([crop_side, crop_side]),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)
class FeatureExtractor(nn.Module):
    """Run a model and capture the outputs of selected named sub-modules.

    A forward hook is registered on each module in ``layers`` (names as listed
    by ``model.named_modules()``). Calling the extractor runs the wrapped model
    and returns a dict mapping each layer name to the output that layer
    produced; the model's own final output is discarded.

    Note: the same dict object is returned, and overwritten, on every call.
    """

    def __init__(self, model: nn.Module, layers: Iterable[str]):
        super().__init__()
        self.model = model
        # Materialise once: ``layers`` is iterated twice below, and a generator
        # would otherwise be exhausted before any hook is registered.
        self.layers = list(layers)
        self._features: Dict[str, Tensor] = {layer: torch.empty(0) for layer in self.layers}
        # Build the name -> module lookup once rather than once per layer.
        modules = dict(self.model.named_modules())
        # Keep the handles so the hooks can be detached later.
        self._hook_handles = []
        for layer_id in self.layers:
            try:
                layer = modules[layer_id]
            except KeyError:
                raise KeyError(f"Layer {layer_id!r} not found in model") from None
            handle = layer.register_forward_hook(self.save_outputs_hook(layer_id))
            self._hook_handles.append(handle)

    def save_outputs_hook(self, layer_id: str) -> Callable:
        """Return a forward hook that stores its module's output under ``layer_id``."""
        def fn(_, __, output):
            self._features[layer_id] = output
        return fn

    def remove_hooks(self) -> None:
        """Detach every forward hook this extractor registered on ``model``."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Run the wrapped model on ``x`` and return the captured layer outputs."""
        self.model(x)
        return self._features
def _load_dataset(args):
if args.PARAMETERS.metric == 'L2':
faiss_metric = faiss.METRIC_L2
dataset = load_dataset(args.PARAMETERS.dataset,split = 'train')
dataset = dataset.add_faiss_index(column=args.ROBUST.embedding_col,metric_type = faiss_metric)
dataset = dataset.add_faiss_index(column=args.NONROBUST.embedding_col,metric_type = faiss_metric)
return dataset
# Module-level setup: config, indexed dataset, query transform, and the two
# backbones wrapped so that each exposes its configured layer's output.
args = OmegaConf.load("configs/resnet.yaml")
wiki_dataset = _load_dataset(args)
TRANSFORMS = get_transform(args)
robust_model = get_model(
    args,
    arch=args.ROBUST.arch,
    load_from=args.ROBUST.load_from,
    arch_path=args.ROBUST.arch_path,
)
non_robust_model = get_model(
    args,
    arch=args.NONROBUST.arch,
    load_from=args.NONROBUST.load_from,
    arch_path=args.NONROBUST.arch_path,
)
fe_robust_model = FeatureExtractor(robust_model, layers=[args.ROBUST.layer])
fe_nonrobust_model = FeatureExtractor(non_robust_model, layers=[args.NONROBUST.layer])
# +
def retrieval_fn(image, radio, k=3):
    """Embed ``image`` with the chosen backbone and return its ``k`` nearest dataset entries.

    Args:
        image: Query image as a NumPy array (presumably what ``gr.Image``
            delivers with its default settings — TODO confirm) or a
            ``PIL.Image.Image``.
        radio: ``'robust'`` uses ``fe_robust_model`` and the
            ``args.ROBUST.embedding_col`` index; ``'standard'`` uses
            ``fe_nonrobust_model`` and the ``args.NONROBUST.embedding_col`` index.
        k: Number of neighbours to retrieve (default 3, as before).

    Returns:
        ``(scores, retrieved_examples)`` from ``wiki_dataset.get_nearest_examples``.

    Raises:
        ValueError: If ``radio`` is not ``'robust'``/``'standard'`` (previously
            an ``UnboundLocalError``) or no image was supplied.
    """
    if radio not in ('robust', 'standard'):
        raise ValueError(f"Unknown model choice {radio!r}; expected 'robust' or 'standard'")
    if image is None:
        raise ValueError("No input image provided")
    if radio == 'robust':
        extractor, cfg = fe_robust_model, args.ROBUST
    else:
        extractor, cfg = fe_nonrobust_model, args.NONROBUST

    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Uploads may be grayscale or RGBA; force 3 channels so the tensor matches
    # what the backbones receive for ordinary RGB images.
    image = image.convert('RGB')
    batch = TRANSFORMS(image).unsqueeze(0).to(args.PARAMETERS.device)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        features = extractor(batch)[cfg.layer]
    emb = features.view(1, -1).cpu().numpy()
    scores, retrieved_examples = wiki_dataset.get_nearest_examples(
        index_name=cfg.embedding_col, query=emb, k=k
    )
    return scores, retrieved_examples
def gradio_fn(image, radio):
    """Gradio click handler: run retrieval and flatten the matches for the UI.

    Returns ``[description_1, image_1, description_2, image_2, description_3,
    image_3]``, one value per output component wired to the Compute button.
    If fewer than three matches come back, the missing slots are filled with
    ``None`` so Gradio still receives exactly six values.
    """
    num_slots = 3
    _scores, retrieved_examples = retrieval_fn(image, radio)
    outputs = []
    # Named ``match_image`` so the ``image`` parameter is not shadowed.
    for description, match_image in zip(
        retrieved_examples['description'], retrieved_examples['image']
    ):
        outputs.extend((description, match_image))
    outputs = outputs[: 2 * num_slots]
    outputs.extend([None] * (2 * num_slots - len(outputs)))
    return outputs
# -
if __name__ == '__main__':
    # Two-column UI: query controls on the left; the three best matches
    # (caption textbox followed by image, for each rank) on the right.
    demo = gr.Blocks()
    with demo:
        gr.Markdown("# Robust vs Standard Image Retrieval")
        with gr.Tabs():
            with gr.TabItem("Upload your Image"):
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            image_input = gr.Image(label="Input Image")
                        with gr.Row():
                            radio_button = gr.Radio(
                                ["robust", "standard"],
                                value="robust",
                                label="OD Model",
                            )
                        with gr.Row():
                            calculate_button = gr.Button("Compute")
                    with gr.Column():
                        match_outputs = []
                        for rank in ("1st", "2nd", "3rd"):
                            match_outputs.append(
                                gr.Textbox(label="Artist / Title / Style / Genre / Date")
                            )
                            match_outputs.append(gr.Image(label=f"{rank} Best match"))
        calculate_button.click(
            fn=gradio_fn,
            inputs=[image_input, radio_button],
            outputs=match_outputs,
        )
    demo.launch(share=False, debug=True)