will33am committed
Commit 068d302
Parent: 97b9646

initial commit

.ipynb_checkpoints/app-checkpoint.py ADDED
@@ -0,0 +1,139 @@
(contents identical to app.py below)
app.py ADDED
@@ -0,0 +1,139 @@
+from datasets import load_dataset
+from torchvision import transforms
+import torch
+from timm import create_model
+from omegaconf import OmegaConf
+import faiss
+import pickle
+import gradio as gr
+import os
+import joblib
+import torch.nn as nn
+from typing import Dict, Iterable, Callable
+from torch import Tensor
+import torchvision
+from PIL import Image
+
+
+def get_model(args,arch,load_from,arch_path):
+    if load_from == 'timm':
+        model = create_model(arch,pretrained = True).to(args.PARAMETERS.device)
+    elif load_from == 'torchvision':
+        if arch == 'resnet50':
+            model = torchvision.models.resnet50(pretrained=False)
+        if len(arch_path)>0:
+            print("Loading pretrained Model")
+            model.load_state_dict(torch.load(arch_path,map_location='cpu')['state_dict'],strict = True)
+    model.eval()
+    return model
+
+
+def get_transform(args):
+    return transforms.Compose([transforms.Resize([args.PARAMETERS.img_resize,args.PARAMETERS.img_resize]),
+                               transforms.CenterCrop([args.PARAMETERS.img_crop,args.PARAMETERS.img_crop]),
+                               transforms.ToTensor()])
+
+
+class FeatureExtractor(nn.Module):
+    def __init__(self, model: nn.Module, layers: Iterable[str]):
+        super().__init__()
+        self.model = model
+        self.layers = layers
+        self._features = {layer: torch.empty(0) for layer in layers}
+
+        for layer_id in layers:
+            layer = dict([*self.model.named_modules()])[layer_id]
+            layer.register_forward_hook(self.save_outputs_hook(layer_id))
+
+    def save_outputs_hook(self, layer_id: str) -> Callable:
+        def fn(_, __, output):
+            self._features[layer_id] = output
+        return fn
+
+    def forward(self, x: Tensor) -> Dict[str, Tensor]:
+        _ = self.model(x)
+        return self._features
+
+
+def _load_dataset(args):
+    if args.PARAMETERS.metric == 'L2':
+        faiss_metric = faiss.METRIC_L2
+    dataset = load_dataset(args.PARAMETERS.dataset,split = 'train')
+    dataset = dataset.add_faiss_index(column=args.ROBUST.embedding_col,metric_type = faiss_metric)
+    dataset = dataset.add_faiss_index(column=args.NONROBUST.embedding_col,metric_type = faiss_metric)
+    return dataset
+
+
+args = OmegaConf.load("configs/resnet.yaml")
+wiki_dataset = _load_dataset(args)
+TRANSFORMS = get_transform(args)
+robust_model = get_model(args,args.ROBUST.arch,args.ROBUST.load_from,args.ROBUST.arch_path)
+non_robust_model = get_model(args,args.NONROBUST.arch,args.NONROBUST.load_from,args.NONROBUST.arch_path)
+fe_robust_model = FeatureExtractor(robust_model,layers = [args.ROBUST.layer])
+fe_nonrobust_model = FeatureExtractor(non_robust_model,layers = [args.NONROBUST.layer])
+
+
+# +
+def retrieval_fn(image,radio):
+    try:
+        image = Image.fromarray(image)
+    except:
+        pass
+    image = TRANSFORMS(image).unsqueeze(0)
+    image = image.to(args.PARAMETERS.device)
+
+    if radio == 'robust':
+        emb = fe_robust_model(image)[args.ROBUST.layer]
+        emb = emb.view(1,-1).detach().cpu().numpy()
+        scores, retrieved_examples = wiki_dataset.get_nearest_examples(index_name = args.ROBUST.embedding_col,
+                                                                       query = emb,
+                                                                       k = 3)
+    elif radio == 'standard':
+        emb = fe_nonrobust_model(image)[args.NONROBUST.layer]
+        emb = emb.view(1,-1).detach().cpu().numpy()
+        scores, retrieved_examples = wiki_dataset.get_nearest_examples(index_name = args.NONROBUST.embedding_col,
+                                                                       query = emb,
+                                                                       k=3)
+    return scores,retrieved_examples
+
+def gradio_fn(image,radio):
+    scores,retrieved_examples = retrieval_fn(image,radio)
+    m = []
+    for description,image,score in zip(retrieved_examples['description'],
+                                       retrieved_examples['image'],
+                                       scores):
+        m.append(description)
+        m.append(image)
+    return m
+
+
+# -
+
+if __name__ == '__main__':
+    demo = gr.Blocks()
+    with demo:
+        gr.Markdown("# Robust vs Standard Image Retrieval")
+        with gr.Tabs():
+            with gr.TabItem("Upload your Image"):
+                with gr.Row():
+                    with gr.Column():
+                        with gr.Row():
+                            image_input = gr.Image(label="Input Image")
+                        with gr.Row():
+                            radio_button = gr.Radio(["robust","standard"],
+                                                    value = "robust",
+                                                    label = "OD Model")
+                        with gr.Row():
+                            calculate_button = gr.Button("Compute")
+                    with gr.Column():
+                        textbox1 = gr.Textbox(label = "Artist / Title / Style / Genre / Date")
+                        output_image1 = gr.Image(label="1st Best match")
+                        textbox2 = gr.Textbox(label = "Artist / Title / Style / Genre / Date")
+                        output_image2 = gr.Image(label="2nd Best match")
+                        textbox3 = gr.Textbox(label = "Artist / Title / Style / Genre / Date")
+                        output_image3 = gr.Image(label="3rd Best match")
+
+        calculate_button.click(fn = gradio_fn,
+                               inputs = [image_input,radio_button],
+                               outputs = [textbox1,output_image1,textbox2,output_image2,textbox3,output_image3])
+    demo.launch(share = False,debug = True)
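The retrieval path in app.py boils down to two pieces: a forward hook that captures the pooled ResNet-50 activation, and a FAISS index that the datasets library attaches to a precomputed embedding column. The standalone sketch below (not part of the commit) reproduces that flow with a randomly initialised torchvision ResNet-50 and a stand-in dataset of random 2048-d vectors; the column name "emb" and the toy data are placeholders for the "resnet50_*_features_2048" columns of Artificio/WikiArt_mini_demos.

    import numpy as np
    import torch
    import torchvision
    import faiss
    from datasets import Dataset

    model = torchvision.models.resnet50()
    model.eval()

    features = {}
    def hook(_module, _inputs, output):
        # the avgpool output has shape (1, 2048, 1, 1)
        features["avgpool"] = output

    dict(model.named_modules())["avgpool"].register_forward_hook(hook)

    with torch.no_grad():
        model(torch.randn(1, 3, 224, 224))            # any preprocessed image tensor
    query = features["avgpool"].view(1, -1).numpy()   # (1, 2048) float32 query vector

    # Stand-in for the WikiArt embedding column: ten random 2048-d vectors.
    ds = Dataset.from_dict({
        "emb": np.random.rand(10, 2048).astype("float32").tolist(),
        "description": [f"artwork {i}" for i in range(10)],
    })
    ds.add_faiss_index(column="emb", metric_type=faiss.METRIC_L2)
    scores, examples = ds.get_nearest_examples(index_name="emb", query=query, k=3)
    print(scores, examples["description"])

Hooking "avgpool" (or "global_pool" for the timm model) is what lets get_nearest_examples receive a 2048-d query without modifying the classifier itself; the L2 metric matches PARAMETERS.metric in the config.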
configs/resnet.yaml ADDED
@@ -0,0 +1,20 @@
+PARAMETERS:
+  img_resize: 256
+  img_crop: 256
+  num_workers: 72
+  device: "cpu"
+  dataset: "Artificio/WikiArt_mini_demos"
+  metric: "L2"
+
+ROBUST:
+  arch: "resnet50"
+  arch_path: "models/robust_resnet50.pt"
+  load_from: "torchvision"
+  layer: "avgpool"
+  embedding_col: "resnet50_robust_features_2048"
+NONROBUST:
+  arch: "resnet50"
+  arch_path: ""
+  load_from: "timm"
+  layer: "global_pool"
+  embedding_col: "resnet50_non_robust_features_2048"
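app.py reads this file with OmegaConf, so every YAML section becomes attribute-style config (args.PARAMETERS.device, args.ROBUST.layer, and so on). A quick illustration of that mapping, run from the repo root:

    from omegaconf import OmegaConf

    args = OmegaConf.load("configs/resnet.yaml")
    print(args.PARAMETERS.dataset)      # "Artificio/WikiArt_mini_demos"
    print(args.ROBUST.embedding_col)    # "resnet50_robust_features_2048"
    print(args.NONROBUST.load_from)     # "timm"

Of these keys, num_workers is the only one app.py never reads; the image sizes, device, dataset name, metric, and the two model sections are all consumed at startup.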
models/robust_resnet50.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9ee26195016452801a20de92d7d8d26f42249b5074301a4aeff342eb565b3c47
+size 102544897
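This entry is only a Git LFS pointer: a plain checkout contains just the three lines above, while the actual checkpoint of roughly 100 MB that get_model loads for the ROBUST configuration (arch_path: "models/robust_resnet50.pt") is fetched with git lfs pull, assuming Git LFS is installed.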