edadaltocg commited on
Commit
301b1c6
1 Parent(s): 3fceb24

update app

Browse files
.gitignore CHANGED
@@ -138,4 +138,5 @@ dmypy.json
138
  cython_debug/
139
 
140
  .DS_Store
141
- .vscode
 
 
138
  cython_debug/
139
 
140
  .DS_Store
141
+ .vscode
142
+ data/
README.md CHANGED
@@ -16,7 +16,7 @@ Out-of-distribution (OOD) detection is an essential safety measure for machine l
16
 
17
  This demo is [online](https://huggingface.co/spaces/edadaltocg/ood-detection) at `https://huggingface.co/spaces/edadaltocg/ood-detection`
18
 
19
- ## Running Gradio app locally:
20
 
21
  1. Install dependencies:
22
 
@@ -31,9 +31,3 @@ python app.py
31
  ```
32
 
33
  3. Open the app in your browser at `http://localhost:7860`.
34
-
35
- ## Methods implemented
36
-
37
- - [ ] [Mahalanobis Distance](https://arxiv.org/abs/1807.03888)
38
- - [x] [Maximum Softmax Probability](https://arxiv.org/abs/1610.02136)
39
- - [x] [Energy Based Out-of-Distribution Detection](https://arxiv.org/abs/2010.03759)
 
16
 
17
  This demo is [online](https://huggingface.co/spaces/edadaltocg/ood-detection) at `https://huggingface.co/spaces/edadaltocg/ood-detection`
18
 
19
+ ## Running Gradio app locally
20
 
21
  1. Install dependencies:
22
 
 
31
  ```
32
 
33
  3. Open the app in your browser at `http://localhost:7860`.
 
 
 
 
 
 
app.py CHANGED
@@ -25,7 +25,7 @@ TOPK = 3
25
 
26
  # load model
27
  print("Loading model...")
28
- model = timm.create_model("resnet50.tv2_in1k", pretrained=True, checkpoint_path="resnet50.tv2_in1k.bin")
29
  model.to(device)
30
  model.eval()
31
 
 
25
 
26
  # load model
27
  print("Loading model...")
28
+ model = timm.create_model("resnet50.tv2_in1k", pretrained=True)
29
  model.to(device)
30
  model.eval()
31
 
centroids_resnet50.tv2_in1k_igeood_logits.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8079b4fc02b6542210d147d98d08b6220372534a18ba7ef9e844b17ab0a1d7e
3
+ size 4000163
imagenet_ood.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from typing import Callable, Optional
4
+
5
+ from torchvision.datasets import ImageFolder
6
+ from torchvision.datasets.utils import check_integrity, download_and_extract_archive, verify_str_arg
7
+
8
+ _logger = logging.getLogger(__name__)
9
+
10
+
11
+ class ImageNetA(ImageFolder):
12
+ """ImageNetA dataset.
13
+
14
+ - Paper: [https://arxiv.org/abs/1907.07174](https://arxiv.org/abs/1907.07174).
15
+ """
16
+
17
+ base_folder = "imagenet-a"
18
+ url = "https://people.eecs.berkeley.edu/~hendrycks/imagenet-a.tar"
19
+ filename = "imagenet-a.tar"
20
+ tgz_md5 = "c3e55429088dc681f30d81f4726b6595"
21
+
22
+ def __init__(self, root: str, split=None, transform: Optional[Callable] = None, download: bool = False, **kwargs):
23
+ self.root = root
24
+
25
+ if download:
26
+ self.download()
27
+
28
+ if not self._check_integrity():
29
+ raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")
30
+
31
+ super().__init__(root=os.path.join(root, self.base_folder), transform=transform, **kwargs)
32
+
33
+ def _check_exists(self) -> bool:
34
+ return os.path.exists(os.path.join(self.root, self.base_folder))
35
+
36
+ def _check_integrity(self) -> bool:
37
+ return check_integrity(os.path.join(self.root, self.filename), self.tgz_md5)
38
+
39
+ def download(self) -> None:
40
+ if self._check_integrity() and self._check_exists():
41
+ _logger.debug("Files already downloaded and verified")
42
+ return
43
+ download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
44
+
45
+
46
+ class ImageNetO(ImageNetA):
47
+ """ImageNetO datasets.
48
+
49
+ Contains unknown classes to ImageNet-1k.
50
+
51
+
52
+ - Paper: [https://arxiv.org/abs/1907.07174](https://arxiv.org/abs/1907.07174)
53
+ """
54
+
55
+ base_folder = "imagenet-o"
56
+ url = "https://people.eecs.berkeley.edu/~hendrycks/imagenet-o.tar"
57
+ filename = "imagenet-o.tar"
58
+ tgz_md5 = "86bd7a50c1c4074fb18fc5f219d6d50b"
59
+
60
+
61
+ class ImageNetR(ImageNetA):
62
+ """ImageNet-R(endition) dataset.
63
+
64
+ Contains art, cartoons, deviantart, graffiti, embroidery, graphics, origami, paintings,
65
+ patterns, plastic objects,plush objects, sculptures, sketches, tattoos, toys,
66
+ and video game renditions of ImageNet-1k classes.
67
+
68
+ - Paper: [https://arxiv.org/abs/2006.16241](https://arxiv.org/abs/2006.16241)
69
+ """
70
+
71
+ base_folder = "imagenet-r"
72
+ url = "https://people.eecs.berkeley.edu/~hendrycks/imagenet-r.tar"
73
+ filename = "imagenet-r.tar"
74
+ tgz_md5 = "a61312130a589d0ca1a8fca1f2bd3337"
75
+
76
+
77
+ class NINCOFull(ImageFolder):
78
+ """`NINCO` Dataset subset.
79
+
80
+ Args:
81
+ root (string): Root directory of dataset where directory
82
+ exists or will be saved to if download is set to True.
83
+ split (string, optional): The dataset split, not used.
84
+ transform (callable, optional): A function/transform that takes in an PIL image
85
+ and returns a transformed version. E.g, `transforms.RandomCrop`.
86
+ download (bool, optional): If true, downloads the dataset from the internet and
87
+ puts it in root directory. If dataset is already downloaded, it is not
88
+ downloaded again.
89
+ **kwargs: Additional arguments passed to :class:`~torchvision.datasets.ImageFolder`.
90
+ """
91
+
92
+ PAPER_URL = "https://arxiv.org/pdf/2306.00826.pdf"
93
+ base_folder = "ninco"
94
+ filename = "NINCO_all.tar.gz"
95
+ file_md5 = "b9ffae324363cd900a81ce3c367cd834"
96
+ url = "https://zenodo.org/record/8013288/files/NINCO_all.tar.gz"
97
+ # size: 15393
98
+
99
+ def __init__(
100
+ self, root: str, split=None, transform: Optional[Callable] = None, download: bool = False, **kwargs
101
+ ) -> None:
102
+ self.root = os.path.expanduser(root)
103
+ self.dataset_folder = os.path.join(self.root, self.base_folder)
104
+ self.archive = os.path.join(self.root, self.filename)
105
+
106
+ if download:
107
+ self.download()
108
+
109
+ if not self._check_integrity():
110
+ raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")
111
+
112
+ super().__init__(self.dataset_folder, transform=transform, **kwargs)
113
+
114
+ def _check_integrity(self) -> bool:
115
+ return check_integrity(self.archive, self.file_md5)
116
+
117
+ def _check_exists(self) -> bool:
118
+ return os.path.exists(self.dataset_folder)
119
+
120
+ def download(self) -> None:
121
+ if self._check_integrity() and self._check_exists():
122
+ return
123
+ download_and_extract_archive(
124
+ self.url, download_root=self.root, extract_root=self.dataset_folder, md5=self.file_md5
125
+ )
126
+
127
+
128
+ if __name__ == "__main__":
129
+ ImageNetR(root="data", download=True)
130
+ ImageNetO(root="data", download=True)
131
+ ImageNetA(root="data", download=True)
132
+ NINCOFull(root="data", download=True)