edadaltocg commited on
Commit
df37aa3
1 Parent(s): db83f58

update images

Browse files
app.py CHANGED
@@ -3,22 +3,21 @@ Gradio demo of image classification with OOD detection.
3
 
4
  If the image example is probably OOD, the model will abstain from the prediction.
5
  """
6
- import os
7
- import pickle
8
  import json
 
 
9
  from glob import glob
10
 
11
  import gradio as gr
12
- from gradio.components import Image, Label, JSON
13
  import numpy as np
14
- import torch
15
  import timm
 
 
 
16
  from timm.data import resolve_data_config
17
  from timm.data.transforms_factory import create_transform
18
  from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
19
 
20
- import logging
21
-
22
  _logger = logging.getLogger(__name__)
23
 
24
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -26,7 +25,7 @@ TOPK = 3
26
 
27
  # load model
28
  print("Loading model...")
29
- model = timm.create_model("resnet50", pretrained=True)
30
  model.to(device)
31
  model.eval()
32
 
@@ -34,28 +33,25 @@ model.eval()
34
  idx2label = json.loads(open("ilsvrc2012.json").read())
35
  idx2label = {int(k): v for k, v in idx2label.items()}
36
  print(idx2label)
 
37
 
38
  # transformation
39
  config = resolve_data_config({}, model=model)
40
  config["is_training"] = False
41
  transform = create_transform(**config)
42
 
43
- # print features names
44
- print(get_graph_node_names(model)[0])
45
-
46
- # load train scores
47
  penultimate_features_key = "global_pool.flatten"
48
  logits_key = "fc"
49
  features_names = [penultimate_features_key, logits_key]
50
 
51
- # create feature extractor
52
  feature_extractor = create_feature_extractor(model, features_names)
53
 
54
- # OOD dtector thresholds
 
55
  msp_threshold = 0.3796
56
  energy_threshold = 0.3781
57
-
58
- ## unpickle detectors
59
 
60
 
61
  def mahalanobis_penult(features):
@@ -72,9 +68,18 @@ def energy(logits):
72
  return torch.logsumexp(logits, dim=1).item()
73
 
74
 
 
 
 
 
 
 
 
 
75
  def predict(image):
76
  # forward pass
77
  inputs = transform(image).unsqueeze(0)
 
78
  with torch.no_grad():
79
  features = feature_extractor(inputs)
80
 
@@ -86,13 +91,16 @@ def predict(image):
86
 
87
  result = {idx2label[i.item()]: v.item() for i, v in zip(class_idxs.squeeze(), softmax.squeeze())}
88
  # OOD
89
- msp_score = msp(features[logits_key])
90
- energy_score = energy(features[logits_key])
 
91
  ood_scores = {
92
- "msp": msp_score,
93
- "msp_is_ood": msp_score < msp_threshold,
94
- "energy": energy_score,
95
- "energy_is_ood": energy_score < energy_threshold,
 
 
96
  }
97
  _logger.info(ood_scores)
98
  return result, ood_scores
@@ -100,9 +108,9 @@ def predict(image):
100
 
101
  def main():
102
  # image examples for demo shuffled
103
- examples = glob("images/imagenet/*.jpg") + glob("images/ood/*.jpg")
104
  np.random.seed(42)
105
- np.random.shuffle(examples)
106
 
107
  # gradio interface
108
  interface = gr.Interface(
@@ -117,11 +125,24 @@ def main():
117
  allow_flagging="never",
118
  theme="default",
119
  title="OOD Detection 🧐",
120
- description="Out-of-distribution (OOD) detection is an essential safety measure for machine learning models. This app demonstrates how these methods can be useful. They try to determine wether we can trust the predictions of a ResNet-50 model trained on ImageNet-1K. Enjoy the demo!",
121
- )
122
- interface.launch(
123
- server_port=7860,
 
 
 
 
 
 
 
 
 
 
 
 
124
  )
 
125
  interface.close()
126
 
127
 
 
3
 
4
  If the image example is probably OOD, the model will abstain from the prediction.
5
  """
 
 
6
  import json
7
+ import logging
8
+ import pickle
9
  from glob import glob
10
 
11
  import gradio as gr
 
12
  import numpy as np
 
13
  import timm
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from gradio.components import JSON, Image, Label
17
  from timm.data import resolve_data_config
18
  from timm.data.transforms_factory import create_transform
19
  from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
20
 
 
 
21
  _logger = logging.getLogger(__name__)
22
 
23
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
25
 
26
  # load model
27
  print("Loading model...")
28
+ model = timm.create_model("resnet50.tv2_in1k", pretrained=True, checkpoint_path="resnet50.tv2_in1k.bin")
29
  model.to(device)
30
  model.eval()
31
 
 
33
  idx2label = json.loads(open("ilsvrc2012.json").read())
34
  idx2label = {int(k): v for k, v in idx2label.items()}
35
  print(idx2label)
36
+ print(idx2label.values())
37
 
38
  # transformation
39
  config = resolve_data_config({}, model=model)
40
  config["is_training"] = False
41
  transform = create_transform(**config)
42
 
43
+ # create feature extractor
 
 
 
44
  penultimate_features_key = "global_pool.flatten"
45
  logits_key = "fc"
46
  features_names = [penultimate_features_key, logits_key]
47
 
 
48
  feature_extractor = create_feature_extractor(model, features_names)
49
 
50
+ centroids = torch.from_numpy(pickle.load(open("centroids_resnet50.tv2_in1k_igeood_logits.pkl", "rb"))).to(device)
51
+ # OOD detector thresholds
52
  msp_threshold = 0.3796
53
  energy_threshold = 0.3781
54
+ igeood_threshold = 2.4984
 
55
 
56
 
57
  def mahalanobis_penult(features):
 
68
  return torch.logsumexp(logits, dim=1).item()
69
 
70
 
71
+ def igeoodlogits_vec(logits, temperature, centroids, epsilon=1e-12):
72
+ logits = torch.sqrt(F.softmax(logits / temperature, dim=1))
73
+ centroids = torch.sqrt(F.softmax(centroids / temperature, dim=1))
74
+ mult = logits @ centroids.T
75
+ stack = 2 * torch.acos(torch.clamp(mult, -1 + epsilon, 1 - epsilon))
76
+ return stack.mean(dim=1).item()
77
+
78
+
79
  def predict(image):
80
  # forward pass
81
  inputs = transform(image).unsqueeze(0)
82
+ inputs = inputs.to(device)
83
  with torch.no_grad():
84
  features = feature_extractor(inputs)
85
 
 
91
 
92
  result = {idx2label[i.item()]: v.item() for i, v in zip(class_idxs.squeeze(), softmax.squeeze())}
93
  # OOD
94
+ msp_score = round(msp(features[logits_key]), 4)
95
+ energy_score = round(energy(features[logits_key]), 4)
96
+ igeood_scores = round(igeoodlogits_vec(features[logits_key], 1, centroids), 4)
97
  ood_scores = {
98
+ "MSP": msp_score,
99
+ "MSP, is the input OOD?": msp_score < msp_threshold,
100
+ "Energy": energy_score,
101
+ "Energy, is the input OOD?": energy_score < energy_threshold,
102
+ "Igeood": igeood_scores,
103
+ "Igeood, is the input OOD?": igeood_scores < igeood_threshold,
104
  }
105
  _logger.info(ood_scores)
106
  return result, ood_scores
 
108
 
109
  def main():
110
  # image examples for demo shuffled
111
+ examples = glob("images/imagenet/*") + glob("images/ood/*")
112
  np.random.seed(42)
113
+ # np.random.shuffle(examples)
114
 
115
  # gradio interface
116
  interface = gr.Interface(
 
125
  allow_flagging="never",
126
  theme="default",
127
  title="OOD Detection 🧐",
128
+ description=(
129
+ "Out-of-distribution (OOD) detection is an essential safety measure for machine learning models. "
130
+ "The objective of an OOD detector is to determine wether the input sample comes from the distribution known by the AI model. "
131
+ "For instance, an input that does not belong to any of the known classes or is from a different domain should be flagged by the detector.\n"
132
+ "In this demo we will display the decision of three OOD detectors on a ResNet-50 model trained to classify on the ImageNet-1K dataset (top-1 accuracy 80%)."
133
+ "This model can classify among 1000 classes from several categories, including `animals`, `vehicles`, `clothing`, `instruments`, `plants`, etc. "
134
+ "For the complete hierarchy of classes, please check the website https://observablehq.com/@mbostock/imagenet-hierarchy. "
135
+ "\n\n"
136
+ "## Instructions:\n"
137
+ "1. Upload an image of your choice or select one from the examples bar.\n"
138
+ "2. The model will predict the top 3 most likely classes for the image.\n"
139
+ "3. The OOD detectors will output their scores and decision on the image. The smaller the score, the least confident the detector is on the sample being in-distribution.\n"
140
+ "4. If the image is OOD, the model will abstain from the prediction and flag it to the practicioner.\n"
141
+ "\n\n\nEnjoy the demo!"
142
+ ),
143
+ cache_examples=True,
144
  )
145
+ interface.launch(server_port=7860)
146
  interface.close()
147
 
148
 
images/imagenet/n02828884_603_bench.jpg DELETED
Binary file (148 kB)
 
images/imagenet/n02834778_3678_bicycle.jpg DELETED
Binary file (75.3 kB)
 
images/imagenet/n02880940_17711_bowl.jpg DELETED
Binary file (68.9 kB)
 
images/imagenet/n03062245_2005_cocktail_shaker.jpg DELETED
Binary file (57.7 kB)
 
images/imagenet/n03495258_9079_harp.jpg DELETED
Binary file (112 kB)
 
images/ood/Rademacher_025_Rademacher_02897.jpg ADDED
images/ood/art_2.jpg ADDED
images/ood/bumpy_0140.jpg DELETED
Binary file (105 kB)
 
images/ood/door_022_00033.jpg ADDED
images/ood/fdb9d2ac3f37c0c80baa7f91775e58ce.jpg DELETED
Binary file (577 kB)
 
images/ood/fed8bd31654ee16a9cd83c8de72ddb5b.jpg DELETED
Binary file (752 kB)
 
images/ood/ff7f83dfb2485306b62bf64726f4f932.jpg DELETED
Binary file (191 kB)
 
images/ood/ffd5b90b142ebcb46cffc96314e6bcd3.jpg DELETED
Binary file (206 kB)
 
images/ood/fireworks_001_0001.png ADDED
images/ood/i_ice_floe_00002019.jpg DELETED
Binary file (58.3 kB)
 
images/ood/i_igloo_00002495.jpg DELETED
Binary file (60.6 kB)
 
images/ood/knitted_0141.jpg DELETED
Binary file (71.9 kB)
 
images/ood/pyramid_008_image_0011.jpg ADDED
images/ood/scissors_040_scissors_0085_pixabay.jpg ADDED
images/ood/striped_0063.jpg DELETED
Binary file (22.3 kB)
 
images/ood/sun_awovauomdhnolaul.jpg DELETED
Binary file (33.2 kB)
 
images/ood/sun_bzrmbfcxyebbxuqu.jpg DELETED
Binary file (361 kB)
 
images/ood/sun_bzuroamlnffhyuqn.jpg DELETED
Binary file (170 kB)
 
images/ood/toy_2.jpg ADDED
images/ood/w_waterfall_00004924.jpg DELETED
Binary file (69.3 kB)
 
images/ood/w_wheat_field_00004628.jpg DELETED
Binary file (105 kB)
 
images/ood/watermelon_0.9992305.JPEG ADDED