HoomKh committed · verified
Commit e5461d8 · 1 Parent(s): ca4180d
Dockerfile ADDED
@@ -0,0 +1,14 @@
+ # -----------------------------------------------------------------------------
+ # A sample Dockerfile to help you replicate our test environment
+ # -----------------------------------------------------------------------------
+
+ FROM pytorch/pytorch:2.4.1-cuda12.4-cudnn9-runtime
+ WORKDIR /app
+ COPY . .
+
+ # Install your Python and apt requirements
+ RUN pip install -r requirements.txt
+ RUN apt-get update && apt-get install -y $(cat apt_requirements.txt)
+ RUN chmod +x run.sh
+
+ CMD ["python3", "runner.py"]
datasets/all_classes_dataset.py ADDED
@@ -0,0 +1,158 @@
+ # datasets/all_classes_dataset.py
+
+ import os
+ from enum import Enum
+
+ import PIL
+ import torch
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+
+ IMAGENET_MEAN = [0.485, 0.456, 0.406]
+ IMAGENET_STD = [0.229, 0.224, 0.225]
+
+
+ class DatasetSplit(Enum):
+     TRAIN = "train"
+     VAL = "val"
+     TEST = "test"
+
+
+ class AllClassesDataset(Dataset):
+     def __init__(
+         self,
+         source,
+         input_size=518,
+         output_size=224,
+         split=DatasetSplit.TEST,
+         external_transform=None,
+         **kwargs,
+     ):
+         """
+         Initialize the dataset to include all classes.
+
+         Args:
+             source (str): Path to the root data directory.
+             input_size (int): Input image size for transformations.
+             output_size (int): Output mask size.
+             split (DatasetSplit): Dataset split to use (TRAIN, VAL, TEST).
+             external_transform (callable, optional): External image transformations.
+             **kwargs: Additional keyword arguments.
+         """
+         super().__init__()
+         self.source = source
+         self.split = split
+         self.input_size = input_size
+         self.classnames_to_use = self.get_all_class_names()
+
+         self.imgpaths_per_class, self.data_to_iterate = self.get_image_data()
+
+         if external_transform is None:
+             self.transform_img = transforms.Compose([
+                 transforms.Resize((input_size, input_size)),
+                 # transforms.CenterCrop(input_size),
+                 transforms.ToTensor(),
+                 transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
+             ])
+         else:
+             self.transform_img = external_transform
+
+         self.transform_mask = transforms.Compose([
+             transforms.Resize((output_size, output_size)),
+             # transforms.CenterCrop(output_size),
+             transforms.ToTensor(),
+         ])
+         self.output_shape = (1, output_size, output_size)
+
+     def get_all_class_names(self):
+         """
+         Retrieve all class names (subdirectories) from the source directory.
+
+         Returns:
+             list: List of class names.
+         """
+         all_items = os.listdir(self.source)
+         classnames = [
+             item for item in all_items
+             if os.path.isdir(os.path.join(self.source, item))
+         ]
+         return classnames
+
+     def get_image_data(self):
+         """
+         Collect image paths and corresponding mask paths for all classes.
+
+         Returns:
+             tuple: (imgpaths_per_class, data_to_iterate)
+         """
+         imgpaths_per_class = {}
+         maskpaths_per_class = {}
+
+         for classname in self.classnames_to_use:
+             classpath = os.path.join(self.source, classname, self.split.value)
+             maskpath = os.path.join(self.source, classname, "ground_truth")
+             anomaly_types = os.listdir(classpath)
+
+             imgpaths_per_class[classname] = {}
+             maskpaths_per_class[classname] = {}
+
+             for anomaly in anomaly_types:
+                 anomaly_path = os.path.join(classpath, anomaly)
+                 anomaly_files = sorted(os.listdir(anomaly_path))
+                 imgpaths_per_class[classname][anomaly] = [
+                     os.path.join(anomaly_path, x) for x in anomaly_files
+                 ]
+
+                 if self.split == DatasetSplit.TEST and anomaly != "good":
+                     anomaly_mask_path = os.path.join(maskpath, anomaly)
+                     if os.path.exists(anomaly_mask_path):
+                         anomaly_mask_files = sorted(os.listdir(anomaly_mask_path))
+                         maskpaths_per_class[classname][anomaly] = [
+                             os.path.join(anomaly_mask_path, x) for x in anomaly_mask_files
+                         ]
+                     else:
+                         # If the mask path does not exist, use None placeholders
+                         maskpaths_per_class[classname][anomaly] = [None] * len(anomaly_files)
+                 else:
+                     maskpaths_per_class[classname][anomaly] = [None] * len(anomaly_files)
+
+         data_to_iterate = []
+         for classname in sorted(imgpaths_per_class.keys()):
+             for anomaly in sorted(imgpaths_per_class[classname].keys()):
+                 for i, image_path in enumerate(imgpaths_per_class[classname][anomaly]):
+                     data_tuple = [classname, anomaly, image_path]
+                     if self.split == DatasetSplit.TEST and anomaly != "good":
+                         mask_path = maskpaths_per_class[classname][anomaly][i]
+                         data_tuple.append(mask_path)
+                     else:
+                         data_tuple.append(None)
+                     data_to_iterate.append(data_tuple)
+
+         return imgpaths_per_class, data_to_iterate
+
+     def __getitem__(self, idx):
+         classname, anomaly, image_path, mask_path = self.data_to_iterate[idx]
+         try:
+             image = PIL.Image.open(image_path).convert("RGB")
+         except Exception:
+             # Fall back to a black image if the file is unreadable
+             image = PIL.Image.new("RGB", (self.input_size, self.input_size), (0, 0, 0))
+         image = self.transform_img(image)
+
+         if self.split == DatasetSplit.TEST and mask_path is not None:
+             try:
+                 mask = PIL.Image.open(mask_path).convert("L")
+                 mask = self.transform_mask(mask) > 0
+             except Exception:
+                 mask = torch.zeros(self.output_shape)
+         else:
+             mask = torch.zeros(self.output_shape)
+
+         return {
+             "image": image,  # Tensor: [3, H, W]
+             "mask": mask,  # Tensor: [1, output_size, output_size]
+             "is_anomaly": int(anomaly != "good"),
+             "image_path": image_path,
+         }
+
+     def __len__(self):
+         return len(self.data_to_iterate)
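A minimal usage sketch for AllClassesDataset (not part of the commit), assuming data under ./data in the {class}/{split}/{anomaly}/*.png layout that get_image_data walks:

from torch.utils.data import DataLoader

from datasets.all_classes_dataset import AllClassesDataset, DatasetSplit

dataset = AllClassesDataset(source="./data", split=DatasetSplit.TEST)
loader = DataLoader(dataset, batch_size=1, shuffle=False)

batch = next(iter(loader))
print(batch["image"].shape)  # torch.Size([1, 3, 518, 518])
print(batch["mask"].shape)   # torch.Size([1, 1, 224, 224])
print(batch["is_anomaly"])   # tensor([0]) for "good", tensor([1]) otherwise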
datasets/mvec.py ADDED
@@ -0,0 +1,301 @@
+ import glob
+ import logging
+ import os
+ from enum import Enum
+
+ import numpy as np
+ import pandas as pd
+ import PIL
+ import torch
+ from torchvision import transforms
+
+ from .perlin import perlin_mask
+
+ LOGGER = logging.getLogger(__name__)
+
+ _CLASSNAMES = [
+     "carpet",
+     "grid",
+     "leather",
+     "tile",
+     "wood",
+     "bottle",
+     "cable",
+     "capsule",
+     "hazelnut",
+     "metal_nut",
+     "pill",
+     "screw",
+     "toothbrush",
+     "transistor",
+     "zipper",
+ ]
+
+ IMAGENET_MEAN = [0.485, 0.456, 0.406]
+ IMAGENET_STD = [0.229, 0.224, 0.225]
+
+
+ class DatasetSplit(Enum):
+     TRAIN = "train"
+     TEST = "test"
+
+
+ class MVTecDataset(torch.utils.data.Dataset):
+     """
+     PyTorch Dataset for MVTec.
+     """
+
+     def __init__(
+         self,
+         source,
+         anomaly_source_path='/root/dataset/dtd/images',
+         dataset_name='mvtec',
+         classname='leather',
+         resize=288,
+         imagesize=288,
+         split=DatasetSplit.TRAIN,
+         rotate_degrees=0,
+         translate=0,
+         brightness_factor=0,
+         contrast_factor=0,
+         saturation_factor=0,
+         gray_p=0,
+         h_flip_p=0,
+         v_flip_p=0,
+         distribution=0,
+         mean=0.5,
+         std=0.1,
+         fg=0,
+         rand_aug=1,
+         scale=0,
+         batch_size=8,
+         **kwargs,
+     ):
+         """
+         Args:
+             source: [str]. Path to the MVTec data folder.
+             classname: [str or None]. Name of the MVTec class that should be
+                        provided by this dataset. If None, the dataset
+                        iterates over all available images.
+             resize: [int]. (Square) Size the loaded image initially gets
+                     resized to.
+             imagesize: [int]. (Square) Size the resized loaded image gets
+                        (center-)cropped to.
+             split: [enum-option]. Indicates if the training or test split of
+                    the data should be used. Has to be an option taken from
+                    DatasetSplit, e.g. mvtec.DatasetSplit.TRAIN. Note that
+                    mvtec.DatasetSplit.TEST will also load mask data.
+         """
+         super().__init__()
+         self.source = source
+         self.split = split
+         self.batch_size = batch_size
+         self.distribution = distribution
+         self.mean = mean
+         self.std = std
+         self.fg = fg
+         self.rand_aug = rand_aug
+         self.resize = resize if self.distribution != 1 else [resize, resize]
+         self.imgsize = imagesize
+         self.imagesize = (3, self.imgsize, self.imgsize)
+         self.classname = classname
+         self.dataset_name = dataset_name
+
+         if self.distribution != 1 and self.classname in ('toothbrush', 'wood'):
+             self.resize = round(self.imgsize * 329 / 288)
+
+         xlsx_path = './datasets/excel/' + self.dataset_name + '_distribution.xlsx'
+         if self.fg == 2:  # choose by file
+             try:
+                 df = pd.read_excel(xlsx_path)
+                 self.class_fg = df.loc[df['Class'] == self.dataset_name + '_' + classname, 'Foreground'].values[0]
+             except Exception:
+                 self.class_fg = 1
+         elif self.fg == 1:  # with foreground mask
+             self.class_fg = 1
+         else:  # without foreground mask
+             self.class_fg = 0
+
+         self.imgpaths_per_class, self.data_to_iterate = self.get_image_data()
+         # DTD images sit four directory levels below the anomaly source root
+         self.anomaly_source_paths = sorted(glob.glob(anomaly_source_path + "/*/*/*/*.png"))
+         LOGGER.debug("Found %d anomaly source images", len(self.anomaly_source_paths))
+
+         self.transform_img = [
+             transforms.Resize(self.resize),
+             transforms.ColorJitter(brightness_factor, contrast_factor, saturation_factor),
+             transforms.RandomHorizontalFlip(h_flip_p),
+             transforms.RandomVerticalFlip(v_flip_p),
+             transforms.RandomGrayscale(gray_p),
+             transforms.RandomAffine(rotate_degrees,
+                                     translate=(translate, translate),
+                                     scale=(1.0 - scale, 1.0 + scale),
+                                     interpolation=transforms.InterpolationMode.BILINEAR),
+             transforms.CenterCrop(self.imgsize),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
+         ]
+         self.transform_img = transforms.Compose(self.transform_img)
+
+         self.transform_mask = [
+             transforms.Resize(self.resize),
+             transforms.CenterCrop(self.imgsize),
+             transforms.ToTensor(),
+         ]
+         self.transform_mask = transforms.Compose(self.transform_mask)
+
+     def rand_augmenter(self):
+         """Compose a transform with three randomly chosen augmentations."""
+         list_aug = [
+             transforms.ColorJitter(contrast=(0.8, 1.2)),
+             transforms.ColorJitter(brightness=(0.8, 1.2)),
+             transforms.ColorJitter(saturation=(0.8, 1.2), hue=(-0.2, 0.2)),
+             transforms.RandomHorizontalFlip(p=1),
+             transforms.RandomVerticalFlip(p=1),
+             transforms.RandomGrayscale(p=1),
+             transforms.RandomAutocontrast(p=1),
+             transforms.RandomEqualize(p=1),
+             transforms.RandomAffine(degrees=(-45, 45)),
+         ]
+         aug_idx = np.random.choice(np.arange(len(list_aug)), 3, replace=False)
+
+         transform_aug = [
+             transforms.Resize(self.resize),
+             list_aug[aug_idx[0]],
+             list_aug[aug_idx[1]],
+             list_aug[aug_idx[2]],
+             transforms.CenterCrop(self.imgsize),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
+         ]
+
+         transform_aug = transforms.Compose(transform_aug)
+         return transform_aug
+
+     def __getitem__(self, idx):
+         try:
+             classname, anomaly, image_path, mask_path = self.data_to_iterate[idx]
+
+             # Load the main image
+             if not os.path.exists(image_path):
+                 LOGGER.warning(f"Image not found: {image_path}. Skipping index {idx}.")
+                 return None
+
+             image = PIL.Image.open(image_path).convert("RGB")
+             image = self.transform_img(image)
+
+             # Initialize default tensors
+             mask_fg = mask_s = aug_image = torch.tensor([1])
+
+             if self.split == DatasetSplit.TRAIN:
+                 try:
+                     aug = PIL.Image.open(np.random.choice(self.anomaly_source_paths)).convert("RGB")
+                     if self.rand_aug:
+                         transform_aug = self.rand_augmenter()
+                         aug = transform_aug(aug)
+                     else:
+                         aug = self.transform_img(aug)
+                 except (IndexError, ValueError):
+                     # np.random.choice raises ValueError on an empty path list
+                     LOGGER.warning(f"No anomaly source images available. Using original image as augmentation for index {idx}.")
+                     aug = image
+
+                 # Handle foreground mask
+                 if self.class_fg:
+                     fgmask_path = (
+                         image_path.split(classname)[0]
+                         + classname
+                         + "/ground_truth/"
+                         + os.path.split(image_path)[-1].replace(".png", "_mask.png")
+                     )
+                     if os.path.exists(fgmask_path):
+                         mask_fg = PIL.Image.open(fgmask_path)
+                         mask_fg = torch.ceil(self.transform_mask(mask_fg)[0])
+                     else:
+                         LOGGER.warning(f"Foreground mask not found: {fgmask_path}. Skipping mask for index {idx}.")
+                         mask_fg = torch.zeros_like(image[0])  # Default empty mask
+
+                 # Generate masks and augmented images
+                 mask_all = perlin_mask(image.shape, self.imgsize // 8, 0, 6, mask_fg, 1)
+                 mask_s = torch.from_numpy(mask_all[0])
+                 mask_l = torch.from_numpy(mask_all[1])
+
+                 beta = np.random.normal(loc=self.mean, scale=self.std)
+                 beta = np.clip(beta, 0.2, 0.8)
+                 aug_image = image * (1 - mask_l) + (1 - beta) * aug * mask_l + beta * image * mask_l
+
+             if self.split == DatasetSplit.TEST and mask_path is not None:
+                 if os.path.exists(mask_path):
+                     mask_gt = PIL.Image.open(mask_path).convert("L")
+                     mask_gt = self.transform_mask(mask_gt)
+                 else:
+                     LOGGER.warning(f"Ground truth mask not found: {mask_path}. Using default empty mask for index {idx}.")
+                     mask_gt = torch.zeros([1, *image.size()[1:]])
+             else:
+                 mask_gt = torch.zeros([1, *image.size()[1:]])
+
+             return {
+                 "image": image,
+                 "aug": aug_image,
+                 "mask_s": mask_s,
+                 "mask_gt": mask_gt,
+                 "is_anomaly": int(anomaly != "good"),
+                 "image_path": image_path,
+             }
+
+         except Exception as e:
+             LOGGER.error(f"Error processing index {idx}: {e}")
+             return None
+
+     def __len__(self):
+         return len(self.data_to_iterate)
+
+     def get_image_data(self):
+         imgpaths_per_class = {}
+         maskpaths_per_class = {}
+
+         classpath = os.path.join(self.source, self.classname, self.split.value)
+         maskpath = os.path.join(self.source, self.classname, "ground_truth")
+         anomaly_types = os.listdir(classpath)
+
+         imgpaths_per_class[self.classname] = {}
+         maskpaths_per_class[self.classname] = {}
+
+         for anomaly in anomaly_types:
+             anomaly_path = os.path.join(classpath, anomaly)
+             anomaly_files = sorted(os.listdir(anomaly_path))
+             imgpaths_per_class[self.classname][anomaly] = [os.path.join(anomaly_path, x) for x in anomaly_files]
+
+             if self.split == DatasetSplit.TEST and anomaly != "good":
+                 anomaly_mask_path = os.path.join(maskpath, anomaly)
+                 if os.path.exists(anomaly_mask_path):
+                     anomaly_mask_files = sorted(os.listdir(anomaly_mask_path))
+                     maskpaths_per_class[self.classname][anomaly] = [os.path.join(anomaly_mask_path, x) for x in anomaly_mask_files]
+                 else:
+                     LOGGER.warning(f"Anomaly mask path does not exist: {anomaly_mask_path}. Skipping masks for {anomaly}.")
+                     maskpaths_per_class[self.classname][anomaly] = []
+             else:
+                 maskpaths_per_class[self.classname]["good"] = None
+
+         data_to_iterate = []
+         for classname in sorted(imgpaths_per_class.keys()):
+             for anomaly in sorted(imgpaths_per_class[classname].keys()):
+                 for i, image_path in enumerate(imgpaths_per_class[classname][anomaly]):
+                     try:
+                         if self.split == DatasetSplit.TEST and anomaly != "good":
+                             if i < len(maskpaths_per_class[classname][anomaly]):
+                                 mask_path = maskpaths_per_class[classname][anomaly][i]
+                             else:
+                                 LOGGER.warning(f"No corresponding mask for {image_path}. Skipping.")
+                                 continue
+                         else:
+                             mask_path = None
+
+                         if os.path.exists(image_path) and (mask_path is None or os.path.exists(mask_path)):
+                             data_to_iterate.append([classname, anomaly, image_path, mask_path])
+                         else:
+                             LOGGER.warning(f"Missing required file for {image_path} or {mask_path}. Skipping.")
+                     except Exception as e:
+                         LOGGER.error(f"Error processing file {image_path}: {e}. Skipping.")
+
+         if len(data_to_iterate) == 0:
+             raise ValueError("No valid data found. Please check dataset paths and files.")
+
+         return imgpaths_per_class, data_to_iterate
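A training-side sketch for MVTecDataset (not part of the commit), assuming MVTec-AD at a hypothetical ./mvtec and the DTD texture set at the default anomaly_source_path; with fg=0 no foreground mask is required:

from torch.utils.data import DataLoader

from datasets.mvec import MVTecDataset, DatasetSplit

train_set = MVTecDataset(
    source="./mvtec",
    classname="leather",
    split=DatasetSplit.TRAIN,
    fg=0,        # skip foreground masking
    rand_aug=1,  # three random ops applied to the DTD texture
)
batch = next(iter(DataLoader(train_set, batch_size=8, shuffle=True)))
# batch["image"]: clean images; batch["aug"]: Perlin-blended pseudo-anomalies;
# batch["mask_s"]: feature-level masks, (imagesize // 8) cells per side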
datasets/perlin.py ADDED
@@ -0,0 +1,73 @@
+ import math
+
+ import imgaug.augmenters as iaa
+ import numpy as np
+ import torch
+
+
+ def generate_thr(img_shape, min_scale=0, max_scale=4):
+     """Sample a binarized Perlin noise map for an image of shape img_shape (C, H, W)."""
+     perlin_scalex = 2 ** np.random.randint(min_scale, max_scale)
+     perlin_scaley = 2 ** np.random.randint(min_scale, max_scale)
+     perlin_noise_np = rand_perlin_2d_np((img_shape[1], img_shape[2]), (perlin_scalex, perlin_scaley))
+     threshold = 0.5
+     perlin_noise_np = iaa.Sequential([iaa.Affine(rotate=(-90, 90))])(image=perlin_noise_np)
+     perlin_thr = np.where(perlin_noise_np > threshold, np.ones_like(perlin_noise_np), np.zeros_like(perlin_noise_np))
+     return perlin_thr
+
+
+ def perlin_mask(img_shape, feat_size, min_scale, max_scale, mask_fg, flag=0):
+     """Sample a non-empty Perlin anomaly mask at feature resolution; if flag != 0,
+     also return the full-resolution mask."""
+     mask = np.zeros((feat_size, feat_size))
+     while np.max(mask) == 0:
+         perlin_thr_1 = generate_thr(img_shape, min_scale, max_scale)
+         perlin_thr_2 = generate_thr(img_shape, min_scale, max_scale)
+         temp = torch.rand(1).numpy()[0]
+         # Randomly combine the two maps: union, intersection, or the first alone
+         if temp > 2 / 3:
+             perlin_thr = perlin_thr_1 + perlin_thr_2
+             perlin_thr = np.where(perlin_thr > 0, np.ones_like(perlin_thr), np.zeros_like(perlin_thr))
+         elif temp > 1 / 3:
+             perlin_thr = perlin_thr_1 * perlin_thr_2
+         else:
+             perlin_thr = perlin_thr_1
+         perlin_thr = torch.from_numpy(perlin_thr)
+         perlin_thr_fg = perlin_thr * mask_fg
+         down_ratio_y = int(img_shape[1] / feat_size)
+         down_ratio_x = int(img_shape[2] / feat_size)
+         # Max-pool the image-level mask down to the feature grid
+         mask = torch.nn.functional.max_pool2d(perlin_thr_fg.unsqueeze(0).unsqueeze(0), (down_ratio_y, down_ratio_x)).float()
+         mask = mask.numpy()[0, 0]
+     mask_s = mask
+     if flag == 0:
+         return mask_s
+     mask_l = perlin_thr_fg.numpy()
+     return mask_s, mask_l
+
+
+ def lerp_np(x, y, w):
+     """Linear interpolation between x and y with weight w."""
+     return (y - x) * w + x
+
+
+ def rand_perlin_2d_np(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
+     """Generate a 2D Perlin noise array of the given shape."""
+     delta = (res[0] / shape[0], res[1] / shape[1])
+     d = (shape[0] // res[0], shape[1] // res[1])
+     grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1
+
+     angles = 2 * math.pi * np.random.rand(res[0] + 1, res[1] + 1)
+     gradients = np.stack((np.cos(angles), np.sin(angles)), axis=-1)
+
+     tile_grads = lambda slice1, slice2: np.repeat(
+         np.repeat(gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]], d[0], axis=0), d[1], axis=1)
+     dot = lambda grad, shift: (
+         np.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]),
+                  axis=-1) * grad[:shape[0], :shape[1]]).sum(axis=-1)
+
+     n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
+     n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
+     n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
+     n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
+     t = fade(grid[:shape[0], :shape[1]])
+     return math.sqrt(2) * lerp_np(lerp_np(n00, n10, t[..., 0]), lerp_np(n01, n11, t[..., 0]), t[..., 1])
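A quick sanity check for the Perlin utilities (not part of the commit): sample one mask pair for a 288×288 image with a 36×36 feature grid, passing mask_fg = tensor([1]) to disable foreground gating, as mvec.py does when class_fg is 0:

import torch

from datasets.perlin import perlin_mask

mask_s, mask_l = perlin_mask((3, 288, 288), 288 // 8, 0, 6, torch.tensor([1]), 1)
print(mask_s.shape)  # (36, 36)   pooled feature-level mask
print(mask_l.shape)  # (288, 288) full-resolution mask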
datasets/rayan_dataset.py ADDED
@@ -0,0 +1,127 @@
+ # -----------------------------------------------------------------------------
+ # Do Not Alter This File!
+ # -----------------------------------------------------------------------------
+ # The following code is part of the logic used for loading and evaluating your
+ # output scores. Please DO NOT modify this section, as upon your submission,
+ # the whole evaluation logic will be overwritten by the original code.
+ # -----------------------------------------------------------------------------
+ # If you'd like to make modifications, you can create a completely new Dataset
+ # class or a child class that inherits from this one and use that with your
+ # data loader.
+ # -----------------------------------------------------------------------------
+
+ import os
+ from enum import Enum
+
+ import PIL
+ import torch
+ from torchvision import transforms
+
+ IMAGENET_MEAN = [0.485, 0.456, 0.406]
+ IMAGENET_STD = [0.229, 0.224, 0.225]
+
+
+ class DatasetSplit(Enum):
+     TRAIN = "train"
+     VAL = "val"
+     TEST = "test"
+
+
+ class RayanDataset(torch.utils.data.Dataset):
+     def __init__(
+         self,
+         source,
+         classname,
+         input_size=518,
+         output_size=224,
+         split=DatasetSplit.TEST,
+         external_transform=None,
+         **kwargs,
+     ):
+         super().__init__()
+         self.source = source
+         self.split = split
+         self.classnames_to_use = [classname]
+         self.imgpaths_per_class, self.data_to_iterate = self.get_image_data()
+
+         if external_transform is None:
+             self.transform_img = [
+                 transforms.Resize((input_size, input_size)),
+                 transforms.CenterCrop(input_size),
+                 transforms.ToTensor(),
+                 transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
+             ]
+             self.transform_img = transforms.Compose(self.transform_img)
+         else:
+             self.transform_img = external_transform
+
+         # Output size of the mask has to be of shape: 1×224×224
+         self.transform_mask = [
+             transforms.Resize((output_size, output_size)),
+             transforms.CenterCrop(output_size),
+             transforms.ToTensor(),
+         ]
+         self.transform_mask = transforms.Compose(self.transform_mask)
+         self.output_shape = (1, output_size, output_size)
+
+     def __getitem__(self, idx):
+         classname, anomaly, image_path, mask_path = self.data_to_iterate[idx]
+         image = PIL.Image.open(image_path).convert("RGB")
+         image = self.transform_img(image)
+
+         if self.split == DatasetSplit.TEST and mask_path is not None:
+             mask = PIL.Image.open(mask_path).convert("L")
+             mask = self.transform_mask(mask) > 0
+         else:
+             mask = torch.zeros([*self.output_shape])
+
+         return {
+             "image": image,
+             "mask": mask,
+             "is_anomaly": int(anomaly != "good"),
+             "image_path": image_path,
+         }
+
+     def __len__(self):
+         return len(self.data_to_iterate)
+
+     def get_image_data(self):
+         imgpaths_per_class = {}
+         maskpaths_per_class = {}
+
+         for classname in self.classnames_to_use:
+             classpath = os.path.join(self.source, classname, self.split.value)
+             maskpath = os.path.join(self.source, classname, "ground_truth")
+             anomaly_types = os.listdir(classpath)
+
+             imgpaths_per_class[classname] = {}
+             maskpaths_per_class[classname] = {}
+
+             for anomaly in anomaly_types:
+                 anomaly_path = os.path.join(classpath, anomaly)
+                 anomaly_files = sorted(os.listdir(anomaly_path))
+                 imgpaths_per_class[classname][anomaly] = [
+                     os.path.join(anomaly_path, x) for x in anomaly_files
+                 ]
+
+                 if self.split == DatasetSplit.TEST and anomaly != "good":
+                     anomaly_mask_path = os.path.join(maskpath, anomaly)
+                     anomaly_mask_files = sorted(os.listdir(anomaly_mask_path))
+                     maskpaths_per_class[classname][anomaly] = [
+                         os.path.join(anomaly_mask_path, x) for x in anomaly_mask_files
+                     ]
+                 else:
+                     maskpaths_per_class[classname]["good"] = None
+
+         data_to_iterate = []
+         for classname in sorted(imgpaths_per_class.keys()):
+             for anomaly in sorted(imgpaths_per_class[classname].keys()):
+                 for i, image_path in enumerate(imgpaths_per_class[classname][anomaly]):
+                     data_tuple = [classname, anomaly, image_path]
+                     if self.split == DatasetSplit.TEST and anomaly != "good":
+                         data_tuple.append(maskpaths_per_class[classname][anomaly][i])
+                     else:
+                         data_tuple.append(None)
+                     data_to_iterate.append(data_tuple)
+
+         return imgpaths_per_class, data_to_iterate
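An evaluation-side sketch of the contract RayanDataset enforces (not part of the commit), assuming data at ./data and one of the mapped category names:

from datasets.rayan_dataset import RayanDataset, DatasetSplit

ds = RayanDataset(source="./data", classname="photovoltaic_module",
                  split=DatasetSplit.TEST)
sample = ds[0]
print(sample["image"].shape)  # torch.Size([3, 518, 518])
print(sample["mask"].shape)   # torch.Size([1, 224, 224]) -- the required 1x224x224 mask
print(sample["is_anomaly"], sample["image_path"])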
docker-compose.yml ADDED
@@ -0,0 +1,21 @@
+ # -----------------------------------------------------------------------------
+ # A sample Docker Compose file to help you replicate our test environment
+ # -----------------------------------------------------------------------------
+
+ services:
+   zsad-service:
+     image: zsad-image:1
+     build:
+       context: .
+     container_name: zsad-container
+     volumes:
+       - ./shared_folder:/app/output
+     deploy:
+       resources:
+         reservations:
+           devices:
+             - driver: nvidia
+               count: all
+               capabilities: [gpu]
+
+     command: [ "python3", "runner.py" ]
evaluation/base_eval.py ADDED
@@ -0,0 +1,293 @@
+ # -----------------------------------------------------------------------------
+ # Do Not Alter This File!
+ # -----------------------------------------------------------------------------
+ # The following code is part of the logic used for loading and evaluating your
+ # output scores. Please DO NOT modify this section, as upon your submission,
+ # the whole evaluation logic will be overwritten by the original code.
+ # -----------------------------------------------------------------------------
+
+ import warnings
+ import os
+ from pathlib import Path
+ import csv
+ import json
+ import torch
+
+ import datasets.rayan_dataset as rayan_dataset
+ from evaluation.utils.metrics import compute_metrics
+
+ warnings.filterwarnings("ignore")
+
+
+ class BaseEval:
+     def __init__(self, cfg):
+         self.cfg = cfg
+         self.device = torch.device(
+             "cuda:{}".format(cfg["device"]) if torch.cuda.is_available() else "cpu"
+         )
+
+         self.path = cfg["datasets"]["data_path"]
+         self.dataset = cfg["datasets"]["dataset_name"]
+         self.save_csv = cfg["testing"]["save_csv"]
+         self.save_json = cfg["testing"]["save_json"]
+         self.categories = cfg["datasets"]["class_name"]
+         if isinstance(self.categories, str):
+             if self.categories.lower() == "all":
+                 if self.dataset == "rayan_dataset":
+                     self.categories = self.get_available_class_names(self.path)
+             else:
+                 self.categories = [self.categories]
+         self.output_dir = cfg["testing"]["output_dir"]
+         os.makedirs(self.output_dir, exist_ok=True)
+         self.scores_dir = cfg["testing"]["output_scores_dir"]
+         self.class_name_mapping_dir = cfg["testing"]["class_name_mapping_dir"]
+
+         self.leaderboard_metric_weights = {
+             "image_auroc": 1.2,
+             "image_ap": 1.1,
+             "image_f1": 1.1,
+             "pixel_auroc": 1.0,
+             "pixel_aupro": 1.4,
+             "pixel_ap": 1.3,
+             "pixel_f1": 1.3,
+         }
+
+     def get_available_class_names(self, root_data_path):
+         all_items = os.listdir(root_data_path)
+         folder_names = [
+             item
+             for item in all_items
+             if os.path.isdir(os.path.join(root_data_path, item))
+         ]
+
+         return folder_names
+
+     def load_datasets(self, category):
+         dataset_classes = {
+             "rayan_dataset": rayan_dataset.RayanDataset,
+         }
+
+         dataset_splits = {
+             "rayan_dataset": rayan_dataset.DatasetSplit.TEST,
+         }
+
+         test_dataset = dataset_classes[self.dataset](
+             source=self.path,
+             split=dataset_splits[self.dataset],
+             classname=category,
+         )
+         return test_dataset
+
+     def get_category_metrics(self, category):
+         print(f"Loading scores of '{category}'")
+         gt_sp, pr_sp, gt_px, pr_px, _ = self.load_category_scores(category)
+
+         print(f"Computing metrics for '{category}'")
+         image_metric, pixel_metric = compute_metrics(gt_sp, pr_sp, gt_px, pr_px)
+
+         return image_metric, pixel_metric
+
+     def load_category_scores(self, category):
+         raise NotImplementedError()
+
+     def get_scores_path_for_image(self, image_path):
+         """example image_path: './data/photovoltaic_module/test/good/037.png'"""
+         path = Path(image_path)
+
+         category, split, anomaly_type = path.parts[-4:-1]
+         image_name = path.stem
+
+         return os.path.join(
+             self.scores_dir, category, split, anomaly_type, f"{image_name}_scores.json"
+         )
+
+     def calc_leaderboard_score(self, **metrics):
+         weighted_sum = 0
+         total_weight = 0
+         for key, weight in self.leaderboard_metric_weights.items():
+             metric = metrics.get(key)
+             weighted_sum += metric * weight
+             total_weight += weight
+
+         if total_weight == 0:
+             return 0
+
+         return weighted_sum / total_weight
+
+     def main(self):
+         image_auroc_list = []
+         image_f1_list = []
+         image_ap_list = []
+         pixel_auroc_list = []
+         pixel_f1_list = []
+         pixel_ap_list = []
+         pixel_aupro_list = []
+         leaderboard_score_list = []
+         for category in self.categories:
+             image_metric, pixel_metric = self.get_category_metrics(
+                 category=category,
+             )
+             image_auroc, image_f1, image_ap = image_metric
+             pixel_auroc, pixel_f1, pixel_ap, pixel_aupro = pixel_metric
+             leaderboard_score = self.calc_leaderboard_score(
+                 image_auroc=image_auroc,
+                 image_f1=image_f1,
+                 image_ap=image_ap,
+                 pixel_auroc=pixel_auroc,
+                 pixel_aupro=pixel_aupro,
+                 pixel_f1=pixel_f1,
+                 pixel_ap=pixel_ap,
+             )
+
+             image_auroc_list.append(image_auroc)
+             image_f1_list.append(image_f1)
+             image_ap_list.append(image_ap)
+             pixel_auroc_list.append(pixel_auroc)
+             pixel_f1_list.append(pixel_f1)
+             pixel_ap_list.append(pixel_ap)
+             pixel_aupro_list.append(pixel_aupro)
+             leaderboard_score_list.append(leaderboard_score)
+
+             print(category)
+             print(
+                 "[image level] auroc:{}, f1:{}, ap:{}".format(
+                     image_auroc * 100,
+                     image_f1 * 100,
+                     image_ap * 100,
+                 )
+             )
+             print(
+                 "[pixel level] auroc:{}, f1:{}, ap:{}, aupro:{}".format(
+                     pixel_auroc * 100,
+                     pixel_f1 * 100,
+                     pixel_ap * 100,
+                     pixel_aupro * 100,
+                 )
+             )
+             print(
+                 "leaderboard score:{}".format(
+                     leaderboard_score * 100,
+                 )
+             )
+
+         image_auroc_mean = sum(image_auroc_list) / len(image_auroc_list)
+         image_f1_mean = sum(image_f1_list) / len(image_f1_list)
+         image_ap_mean = sum(image_ap_list) / len(image_ap_list)
+         pixel_auroc_mean = sum(pixel_auroc_list) / len(pixel_auroc_list)
+         pixel_f1_mean = sum(pixel_f1_list) / len(pixel_f1_list)
+         pixel_ap_mean = sum(pixel_ap_list) / len(pixel_ap_list)
+         pixel_aupro_mean = sum(pixel_aupro_list) / len(pixel_aupro_list)
+         leaderboard_score_mean = sum(leaderboard_score_list) / len(
+             leaderboard_score_list
+         )
+
+         print("mean")
+         print(
+             "[image level] auroc:{}, f1:{}, ap:{}".format(
+                 image_auroc_mean * 100, image_f1_mean * 100, image_ap_mean * 100
+             )
+         )
+         print(
+             "[pixel level] auroc:{}, f1:{}, ap:{}, aupro:{}".format(
+                 pixel_auroc_mean * 100,
+                 pixel_f1_mean * 100,
+                 pixel_ap_mean * 100,
+                 pixel_aupro_mean * 100,
+             )
+         )
+         print(
+             "leaderboard score:{}".format(
+                 leaderboard_score_mean * 100,
+             )
+         )
+
+         # Save the final results as a csv file
+         if self.save_csv:
+             with open(self.class_name_mapping_dir, "r") as f:
+                 class_name_mapping_dict = json.load(f)
+             csv_data = [
+                 [
+                     "Category",
+                     "pixel_auroc",
+                     "pixel_f1",
+                     "pixel_ap",
+                     "pixel_aupro",
+                     "image_auroc",
+                     "image_f1",
+                     "image_ap",
+                     "leaderboard_score",
+                 ]
+             ]
+             for i, category in enumerate(self.categories):
+                 csv_data.append(
+                     [
+                         class_name_mapping_dict[category],
+                         pixel_auroc_list[i] * 100,
+                         pixel_f1_list[i] * 100,
+                         pixel_ap_list[i] * 100,
+                         pixel_aupro_list[i] * 100,
+                         image_auroc_list[i] * 100,
+                         image_f1_list[i] * 100,
+                         image_ap_list[i] * 100,
+                         leaderboard_score_list[i] * 100,
+                     ]
+                 )
+             csv_data.append(
+                 [
+                     "mean",
+                     pixel_auroc_mean * 100,
+                     pixel_f1_mean * 100,
+                     pixel_ap_mean * 100,
+                     pixel_aupro_mean * 100,
+                     image_auroc_mean * 100,
+                     image_f1_mean * 100,
+                     image_ap_mean * 100,
+                     leaderboard_score_mean * 100,
+                 ]
+             )
+
+             csv_file_path = os.path.join(self.output_dir, "results.csv")
+             with open(csv_file_path, mode="w", newline="") as file:
+                 writer = csv.writer(file)
+                 writer.writerows(csv_data)
+
+         # Save the final results as a json file
+         if self.save_json:
+             json_data = []
+             with open(self.class_name_mapping_dir, "r") as f:
+                 class_name_mapping_dict = json.load(f)
+             for i, category in enumerate(self.categories):
+                 json_data.append(
+                     {
+                         "Category": class_name_mapping_dict[category],
+                         "pixel_auroc": pixel_auroc_list[i] * 100,
+                         "pixel_f1": pixel_f1_list[i] * 100,
+                         "pixel_ap": pixel_ap_list[i] * 100,
+                         "pixel_aupro": pixel_aupro_list[i] * 100,
+                         "image_auroc": image_auroc_list[i] * 100,
+                         "image_f1": image_f1_list[i] * 100,
+                         "image_ap": image_ap_list[i] * 100,
+                         "leaderboard_score": leaderboard_score_list[i] * 100,
+                     }
+                 )
+             json_data.append(
+                 {
+                     "Category": "mean",
+                     "pixel_auroc": pixel_auroc_mean * 100,
+                     "pixel_f1": pixel_f1_mean * 100,
+                     "pixel_ap": pixel_ap_mean * 100,
+                     "pixel_aupro": pixel_aupro_mean * 100,
+                     "image_auroc": image_auroc_mean * 100,
+                     "image_f1": image_f1_mean * 100,
+                     "image_ap": image_ap_mean * 100,
+                     "leaderboard_score": leaderboard_score_mean * 100,
+                 }
+             )
+
+             json_file_path = os.path.join(self.output_dir, "results.json")
+             with open(json_file_path, mode="w") as file:
+                 final_json = {
+                     "result": leaderboard_score_mean * 100,
+                     "metadata": json_data,
+                 }
+                 json.dump(final_json, file, indent=4)
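The leaderboard score above is simply a weighted mean of the seven metrics. A standalone recomputation (weights copied from leaderboard_metric_weights; the metric values are made up):

weights = {
    "image_auroc": 1.2, "image_ap": 1.1, "image_f1": 1.1,
    "pixel_auroc": 1.0, "pixel_aupro": 1.4, "pixel_ap": 1.3, "pixel_f1": 1.3,
}
metrics = {
    "image_auroc": 0.95, "image_ap": 0.97, "image_f1": 0.93,
    "pixel_auroc": 0.96, "pixel_aupro": 0.88, "pixel_ap": 0.55, "pixel_f1": 0.58,
}
# Weighted mean; the weights sum to 8.4
score = sum(metrics[k] * w for k, w in weights.items()) / sum(weights.values())
print(round(score, 4))  # 0.8204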
evaluation/class_name_mapping.json ADDED
@@ -0,0 +1,5 @@
+ {
+     "pill": "industrial_01",
+     "photovoltaic_module": "industrial_02",
+     "capsules": "industrial_03"
+ }
evaluation/eval_main.py ADDED
@@ -0,0 +1,78 @@
+ # -----------------------------------------------------------------------------
+ # Do Not Alter This File!
+ # -----------------------------------------------------------------------------
+ # The following code is part of the logic used for loading and evaluating your
+ # output scores. Please DO NOT modify this section, as upon your submission,
+ # the whole evaluation logic will be overwritten by the original code.
+ # -----------------------------------------------------------------------------
+
+ import warnings
+ import argparse
+ import os
+ import sys
+
+ sys.path.append(os.getcwd())
+ from evaluation.json_score import JsonScoreEvaluator
+
+ warnings.filterwarnings("ignore")
+
+
+ def get_args():
+     parser = argparse.ArgumentParser(description="Rayan ZSAD Evaluation Code")
+     parser.add_argument("--data_path", type=str, default=None, help="dataset path")
+     parser.add_argument("--dataset_name", type=str, default=None, help="dataset name")
+     parser.add_argument("--class_name", type=str, default=None, help="category")
+     parser.add_argument("--device", type=int, default=None, help="gpu id")
+     parser.add_argument(
+         "--output_dir", type=str, default=None, help="save results path"
+     )
+     parser.add_argument(
+         "--output_scores_dir", type=str, default=None, help="save scores path"
+     )
+     parser.add_argument("--save_csv", type=str, default=None, help="save csv")
+     parser.add_argument("--save_json", type=str, default=None, help="save json")
+
+     parser.add_argument(
+         "--class_name_mapping_dir",
+         type=str,
+         default=None,
+         help="mapping from actual class names to class numbers",
+     )
+     args = parser.parse_args()
+     return args
+
+
+ def load_args(cfg, args):
+     cfg["datasets"]["data_path"] = args.data_path
+     assert os.path.exists(
+         cfg["datasets"]["data_path"]
+     ), f"The dataset path {cfg['datasets']['data_path']} does not exist."
+     cfg["datasets"]["dataset_name"] = args.dataset_name
+     cfg["datasets"]["class_name"] = args.class_name
+     cfg["device"] = args.device
+     if isinstance(cfg["device"], int):
+         cfg["device"] = str(cfg["device"])
+     cfg["testing"]["output_dir"] = args.output_dir
+     cfg["testing"]["output_scores_dir"] = args.output_scores_dir
+     os.makedirs(cfg["testing"]["output_scores_dir"], exist_ok=True)
+
+     cfg["testing"]["class_name_mapping_dir"] = args.class_name_mapping_dir
+     if args.save_csv.lower() == "true":
+         cfg["testing"]["save_csv"] = True
+     else:
+         cfg["testing"]["save_csv"] = False
+
+     if args.save_json.lower() == "true":
+         cfg["testing"]["save_json"] = True
+     else:
+         cfg["testing"]["save_json"] = False
+
+     return cfg
+
+
+ if __name__ == "__main__":
+     args = get_args()
+     cfg = load_args(cfg={"datasets": {}, "testing": {}, "models": {}}, args=args)
+     print(cfg)
+     model = JsonScoreEvaluator(cfg=cfg)
+     model.main()
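For reference, after load_args runs, the cfg handed to JsonScoreEvaluator has this shape (the paths here are hypothetical, matching defaults used elsewhere in the repo):

cfg = {
    "datasets": {
        "data_path": "./data",
        "dataset_name": "rayan_dataset",
        "class_name": "all",
    },
    "testing": {
        "output_dir": "./output",
        "output_scores_dir": "./output_scores",
        "class_name_mapping_dir": "./evaluation/class_name_mapping.json",
        "save_csv": True,
        "save_json": True,
    },
    "models": {},
    "device": "0",
}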
evaluation/json_score.py ADDED
@@ -0,0 +1,98 @@
+ # -----------------------------------------------------------------------------
+ # Do Not Alter This File!
+ # -----------------------------------------------------------------------------
+ # The following code is part of the logic used for loading and evaluating your
+ # output scores. Please DO NOT modify this section, as upon your submission,
+ # the whole evaluation logic will be overwritten by the original code.
+ # -----------------------------------------------------------------------------
+
+ import warnings
+ import numpy as np
+ import torch
+ from tqdm import tqdm
+
+ from evaluation.base_eval import BaseEval
+ from evaluation.utils.json_helpers import json_to_dict
+
+ warnings.filterwarnings("ignore")
+
+
+ class JsonScoreEvaluator(BaseEval):
+     """
+     Evaluates anomaly detection performance based on pre-computed scores stored in JSON files.
+
+     This class extends the BaseEval class and specializes in reading scores from JSON files,
+     computing evaluation metrics, and optionally saving results to CSV or JSON format.
+
+     Notes:
+         - Score files are expected to follow the exact dataset structure.
+           `{category}/{split}/{anomaly_type}/{image_name}_scores.json`
+           e.g., `photovoltaic_module/test/good/037_scores.json`
+         - Score files are expected to be at `self.scores_dir`.
+
+     Example usage:
+         >>> evaluator = JsonScoreEvaluator(cfg)
+         >>> results = evaluator.main()
+     """
+
+     def __init__(self, cfg):
+         super().__init__(cfg)
+
+     def get_scores_for_image(self, image_path):
+         image_scores_path = self.get_scores_path_for_image(image_path)
+         image_scores = json_to_dict(image_scores_path)
+
+         return image_scores
+
+     def load_category_scores(self, category):
+         cls_scores_list = []  # image level prediction
+         anomaly_maps = []  # pixel level prediction
+         gt_list = []  # image level ground truth
+         img_masks = []  # pixel level ground truth
+
+         image_path_list = []
+         test_dataset = self.load_datasets(category)
+         test_dataloader = torch.utils.data.DataLoader(
+             test_dataset,
+             batch_size=1,
+             shuffle=False,
+             num_workers=0,
+             pin_memory=True,
+         )
+
+         for image_info in tqdm(test_dataloader):
+             if not isinstance(image_info, dict):
+                 raise ValueError("Encountered non-dict image in dataloader")
+
+             del image_info["image"]
+
+             image_path = image_info["image_path"][0]
+             image_path_list.extend(image_path)
+
+             img_masks.append(image_info["mask"])
+             gt_list.extend(list(image_info["is_anomaly"].numpy()))
+
+             image_scores = self.get_scores_for_image(image_path)
+             cls_scores = image_scores["img_level_score"]
+             anomaly_maps_iter = image_scores["pix_level_score"]
+
+             cls_scores_list.append(cls_scores)
+             anomaly_maps.append(anomaly_maps_iter)
+
+         pr_sp = np.array(cls_scores_list)
+         gt_sp = np.array(gt_list)
+         pr_px = np.array(anomaly_maps)
+         gt_px = torch.cat(img_masks, dim=0).numpy().astype(np.int32)
+         print(pr_px.shape)
+         assert pr_px.shape[1:] == (
+             1,
+             224,
+             224,
+         ), "Predicted output scores do not meet the expected shape!"
+         assert gt_px.shape[1:] == (
+             1,
+             224,
+             224,
+         ), "Loaded ground truth maps do not meet the expected shape!"
+
+         return gt_sp, pr_sp, gt_px, pr_px, image_path_list
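A sketch of producing one score file the evaluator can read back (not part of the commit); the 1×224×224 array shape matches the asserts above, and the path mirrors get_scores_path_for_image:

import os

import numpy as np

from evaluation.utils.json_helpers import dict_to_json

out_path = "./output_scores/photovoltaic_module/test/good/037_scores.json"
os.makedirs(os.path.dirname(out_path), exist_ok=True)
scores = {
    "img_level_score": 0.73,  # one scalar per image
    "pix_level_score": np.random.rand(1, 224, 224).astype(np.float32),
}
dict_to_json(scores, out_path)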
evaluation/utils/json_helpers.py ADDED
@@ -0,0 +1,46 @@
+ # -----------------------------------------------------------------------------
+ # Do Not Alter This File!
+ # -----------------------------------------------------------------------------
+ # The following code is part of the logic used for loading and evaluating your
+ # output scores. Please DO NOT modify this section, as upon your submission,
+ # the whole evaluation logic will be overwritten by the original code.
+ # -----------------------------------------------------------------------------
+
+ import json
+ import numpy as np
+
+
+ class NumpyEncoder(json.JSONEncoder):
+     """Special json encoder for numpy types"""
+
+     def default(self, obj):
+         if isinstance(obj, np.integer):
+             return int(obj)
+         elif isinstance(obj, np.floating):
+             return float(obj)
+         elif isinstance(obj, np.ndarray):
+             return {
+                 "__ndarray__": obj.tolist(),
+                 "dtype": str(obj.dtype),
+                 "shape": obj.shape,
+             }
+         else:
+             return super(NumpyEncoder, self).default(obj)
+
+
+ def dict_to_json(dct, filename):
+     """Save a dictionary to a JSON file"""
+     with open(filename, "w") as f:
+         json.dump(dct, f, cls=NumpyEncoder)
+
+
+ def json_to_dict(filename):
+     """Load a JSON file and convert it back to a dictionary of NumPy arrays"""
+     with open(filename, "r") as f:
+         dct = json.load(f)
+
+     for k, v in dct.items():
+         if isinstance(v, dict) and "__ndarray__" in v:
+             dct[k] = np.array(v["__ndarray__"], dtype=v["dtype"]).reshape(v["shape"])
+
+     return dct
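A round-trip check for the helpers above (not part of the commit): ndarrays come back with dtype and shape intact.

import numpy as np

from evaluation.utils.json_helpers import dict_to_json, json_to_dict

original = {"img_level_score": 0.5, "pix_level_score": np.zeros((1, 4, 4), dtype=np.float32)}
dict_to_json(original, "/tmp/roundtrip.json")
restored = json_to_dict("/tmp/roundtrip.json")
assert restored["pix_level_score"].shape == (1, 4, 4)
assert restored["pix_level_score"].dtype == np.float32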
evaluation/utils/metrics.py ADDED
@@ -0,0 +1,111 @@
+ # -----------------------------------------------------------------------------
+ # Do Not Alter This File!
+ # -----------------------------------------------------------------------------
+ # The following code is part of the logic used for loading and evaluating your
+ # output scores. Please DO NOT modify this section, as upon your submission,
+ # the whole evaluation logic will be overwritten by the original code.
+ # -----------------------------------------------------------------------------
+
+ import numpy as np
+ from sklearn.metrics import (
+     auc,
+     roc_auc_score,
+     average_precision_score,
+     precision_recall_curve,
+ )
+ from skimage import measure
+ import warnings
+
+
+ # ref: https://github.com/gudovskiy/cflow-ad/blob/master/train.py
+ def cal_pro_score(masks, amaps, max_step=200, expect_fpr=0.3):
+     binary_amaps = np.zeros_like(amaps, dtype=bool)
+     min_th, max_th = amaps.min(), amaps.max()
+     delta = (max_th - min_th) / max_step
+     pros, fprs, ths = [], [], []
+     for th in np.arange(min_th, max_th, delta):
+         binary_amaps[amaps <= th], binary_amaps[amaps > th] = 0, 1
+         pro = []
+         for binary_amap, mask in zip(binary_amaps, masks):
+             for region in measure.regionprops(measure.label(mask)):
+                 tp_pixels = binary_amap[region.coords[:, 0], region.coords[:, 1]].sum()
+                 pro.append(tp_pixels / region.area)
+         inverse_masks = 1 - masks
+         fp_pixels = np.logical_and(inverse_masks, binary_amaps).sum()
+         fpr = fp_pixels / inverse_masks.sum()
+         pros.append(np.array(pro).mean())
+         fprs.append(fpr)
+         ths.append(th)
+     pros, fprs, ths = np.array(pros), np.array(fprs), np.array(ths)
+     idxes = fprs < expect_fpr
+     fprs = fprs[idxes]
+     print("fprs: ", fprs)
+     fprs = (fprs - fprs.min()) / (fprs.max() - fprs.min())
+     pro_auc = auc(fprs, pros[idxes])
+     return pro_auc
+
+
+ def compute_metrics(gt_sp=None, pr_sp=None, gt_px=None, pr_px=None):
+     # classification
+     if (
+         gt_sp is None
+         or pr_sp is None
+         or gt_sp.sum() == 0
+         or gt_sp.sum() == gt_sp.shape[0]
+     ):
+         auroc_sp, f1_sp, ap_sp = 0, 0, 0
+     else:
+         auroc_sp = roc_auc_score(gt_sp, pr_sp)
+         ap_sp = average_precision_score(gt_sp, pr_sp)
+         precisions, recalls, thresholds = precision_recall_curve(gt_sp, pr_sp)
+         f1_scores = (2 * precisions * recalls) / (precisions + recalls)
+         f1_sp = np.max(f1_scores[np.isfinite(f1_scores)])
+
+     # segmentation
+     if gt_px is None or pr_px is None or gt_px.sum() == 0:
+         auroc_px, f1_px, ap_px, aupro = 0, 0, 0, 0
+     else:
+         auroc_px = roc_auc_score(gt_px.ravel(), pr_px.ravel())
+         ap_px = average_precision_score(gt_px.ravel(), pr_px.ravel())
+         precisions, recalls, thresholds = precision_recall_curve(
+             gt_px.ravel(), pr_px.ravel()
+         )
+         f1_scores = (2 * precisions * recalls) / (precisions + recalls)
+         f1_px = np.max(f1_scores[np.isfinite(f1_scores)])
+         aupro = cal_pro_score(gt_px.squeeze(), pr_px.squeeze())
+
+     image_metric = [auroc_sp, f1_sp, ap_sp]
+     pixel_metric = [auroc_px, f1_px, ap_px, aupro]
+
+     return image_metric, pixel_metric
+
+
+ def compute_auroc(labels, scores):
+     """
+     Computes the Area Under the Receiver Operating Characteristic Curve (AUROC).
+
+     Args:
+         labels (list or np.ndarray): True binary labels (0 for normal, 1 for anomaly).
+         scores (list or np.ndarray): Predicted scores or probabilities for the positive class.
+
+     Returns:
+         float: AUROC score. Returns None if AUROC is undefined.
+     """
+     # Convert inputs to numpy arrays
+     labels = np.array(labels)
+     scores = np.array(scores)
+
+     # Ensure that labels are binary
+     unique_labels = np.unique(labels)
+     if set(unique_labels) != {0, 1}:
+         raise ValueError(f"Labels must be binary (0 and 1). Found labels: {unique_labels}")
+
+     # Check if both classes are present
+     if len(unique_labels) < 2:
+         warnings.warn("Only one class present in labels. AUROC is undefined.")
+         return None
+
+     try:
+         auroc = roc_auc_score(labels, scores)
+         return auroc
+     except ValueError as e:
+         warnings.warn(f"Error computing AUROC: {e}")
+         return None
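A toy invocation of compute_metrics (not part of the commit). Note the guards above: image metrics are zeroed when all labels are one class, and cal_pro_score needs at least one labeled defect region.

import numpy as np

from evaluation.utils.metrics import compute_metrics

rng = np.random.default_rng(0)
gt_sp = np.array([0, 0, 1, 1])          # image labels: two good, two anomalous
pr_sp = np.array([0.1, 0.2, 0.8, 0.9])  # image-level scores
gt_px = np.zeros((4, 8, 8), dtype=np.int32)
gt_px[2, 2:5, 2:5] = 1                  # small defect regions
gt_px[3, 5:7, 5:7] = 1
pr_px = rng.random((4, 8, 8)) * 0.3
pr_px[gt_px == 1] += 0.7                # score defect pixels higher
image_metric, pixel_metric = compute_metrics(gt_sp, pr_sp, gt_px, pr_px)
print(image_metric)  # [auroc_sp, f1_sp, ap_sp]
print(pixel_metric)  # [auroc_px, f1_px, ap_px, aupro]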
main.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # main.py
2
+
3
+ import os
4
+ import torch
5
+ from torch.utils.data import DataLoader
6
+ from datasets.all_classes_dataset import AllClassesDataset, DatasetSplit
7
+ from models.anomaly_detector import AnomalyDetector
8
+ from utils.dump_scores import DumpScores
9
+ import logging
10
+ import json
11
+ from sklearn.metrics import average_precision_score, roc_auc_score, f1_score
12
+ import numpy as np
13
+ import torch.nn.functional as F
14
+ import random
15
+
16
+ def set_seed(seed: int):
17
+ """
18
+ Set the seed for reproducibility across various libraries.
19
+
20
+ Args:
21
+ seed (int): The seed value to be set.
22
+ """
23
+ random.seed(seed)
24
+ np.random.seed(seed)
25
+ torch.manual_seed(seed)
26
+
27
+ if torch.cuda.is_available():
28
+ torch.cuda.manual_seed(seed)
29
+ torch.cuda.manual_seed_all(seed) # For multi-GPU setups
30
+
31
+ # Ensure deterministic behavior in PyTorch
32
+ torch.backends.cudnn.deterministic = True
33
+ torch.backends.cudnn.benchmark = False
34
+
35
+ # For DataLoader workers
36
+ os.environ['PYTHONHASHSEED'] = str(seed)
37
+
38
+ def worker_init_fn(worker_id):
39
+ """
40
+ Initialize the seed for each DataLoader worker to ensure reproducibility.
41
+
42
+ Args:
43
+ worker_id (int): The worker ID.
44
+ """
45
+ seed = torch.initial_seed()
46
+ np.random.seed(seed % 2**32)
47
+ random.seed(seed % 2**32)
48
+
49
+ def compute_aupro(y_true_pixel, y_scores_pixel, num_thresholds=50):
50
+ """
51
+ Compute Area Under the Per-Region Overlap Curve (AUPRO).
52
+
53
+ Args:
54
+ y_true_pixel (np.ndarray): Ground truth binary masks, shape [N, H, W]
55
+ y_scores_pixel (np.ndarray): Predicted anomaly scores, shape [N, H, W]
56
+ num_thresholds (int): Number of thresholds to evaluate.
57
+
58
+ Returns:
59
+ float: AUPRO score.
60
+ """
61
+ # Define thresholds
62
+ thresholds = np.linspace(0, 1, num_thresholds)
63
+
64
+ # Initialize list to store overlaps
65
+ overlaps = []
66
+
67
+ for thresh in thresholds:
68
+ # Binarize predictions
69
+ y_pred = (y_scores_pixel >= thresh).astype(int)
70
+
71
+ # Compute Intersection over Union (IoU) for each sample
72
+ ious = []
73
+ for gt, pred in zip(y_true_pixel, y_pred):
74
+ intersection = np.logical_and(gt, pred).sum()
75
+ union = np.logical_or(gt, pred).sum()
76
+ if union == 0:
77
+ iou = 1.0 # If both gt and pred are all zeros
78
+ else:
79
+ iou = intersection / union
80
+ ious.append(iou)
81
+
82
+ # Average IoU over all samples
83
+ avg_iou = np.mean(ious)
84
+ overlaps.append(avg_iou)
85
+
86
+ # Compute the area under the overlap curve
87
+ aupro = np.trapz(overlaps, thresholds) / np.trapz([1] * len(thresholds), thresholds) # Normalize
88
+ return aupro
89
+
90
+
91
+ def compute_metrics(y_true_image, y_scores_image, y_true_pixel, y_scores_pixel):
92
+ """
93
+ Compute the required metrics based on true labels and predicted scores.
94
+
95
+ Args:
96
+ y_true_image (np.ndarray): Ground truth image labels, shape [N]
97
+ y_scores_image (np.ndarray): Predicted image scores, shape [N]
98
+ y_true_pixel (np.ndarray): Ground truth pixel masks, shape [N, H, W]
99
+ y_scores_pixel (np.ndarray): Predicted pixel anomaly scores, shape [N, H, W]
100
+
101
+ Returns:
102
+ dict: Dictionary containing computed metrics.
103
+ """
104
+ # Check image-level consistency
105
+ if len(y_true_image) != len(y_scores_image):
106
+ raise ValueError(f"Image-level y_true and y_scores have different lengths: {len(y_true_image)} vs {len(y_scores_image)}")
107
+
108
+ # Check pixel-level consistency
109
+ if y_true_pixel.shape != y_scores_pixel.shape:
110
+ raise ValueError(f"Pixel-level y_true and y_scores have different shapes: {y_true_pixel.shape} vs {y_scores_pixel.shape}")
111
+
112
+ # Image-level Metrics
113
+ image_ap = average_precision_score(y_true_image, y_scores_image)
114
+ image_auroc = roc_auc_score(y_true_image, y_scores_image)
115
+ y_pred_image = (y_scores_image >= 0.5).astype(int)
116
+ image_f1 = f1_score(y_true_image, y_pred_image)
117
+
118
+ # Pixel-level Metrics
119
+ pixel_ap = average_precision_score(y_true_pixel.flatten(), y_scores_pixel.flatten())
120
+ pixel_auroc = roc_auc_score(y_true_pixel.flatten(), y_scores_pixel.flatten())
121
+ pixel_aupro = compute_aupro(y_true_pixel, y_scores_pixel)
122
+ y_pred_pixel = (y_scores_pixel >= 0.5).astype(int)
123
+ pixel_f1 = f1_score(y_true_pixel.flatten(), y_pred_pixel.flatten())
124
+
125
+ # Compute leaderboard_score as a weighted average (example weights)
126
+ # Adjust weights as per your specific requirements
127
+ leaderboard_score = (
128
+ 0.25 * image_auroc +
129
+ 0.25 * image_f1 +
130
+ 0.25 * pixel_auroc +
131
+ 0.25 * pixel_f1
132
+ )
133
+
134
+ metrics = {
135
+ "image_metrics": {
136
+ "image_ap": round(float(image_ap), 4),
137
+ "image_auroc": round(float(image_auroc), 4),
138
+ "image_f1": round(float(image_f1), 4)
139
+ },
140
+ "pixel_metrics": {
141
+ "pixel_ap": round(float(pixel_ap), 4),
142
+ "pixel_aupro": round(float(pixel_aupro), 4),
143
+ "pixel_auroc": round(float(pixel_auroc), 4),
144
+ "pixel_f1": round(float(pixel_f1), 4)
145
+ },
146
+ "overall_metric": {
147
+ "leaderboard_score": round(float(leaderboard_score), 4)
148
+ }
149
+ }
150
+
151
+ return metrics
152
+
153
+
+ 
+ 
+ def get_class_name(image_path, source_dir):
+     """
+     Extract the class name from the image path.
+ 
+     Args:
+         image_path (str): Path to the image file.
+         source_dir (str): Root source directory.
+ 
+     Returns:
+         str: Class name.
+     """
+     # Example image_path: "./data/pill/test/broken/image1.png"
+     rel_path = os.path.relpath(image_path, source_dir)  # "pill/test/broken/image1.png"
+     parts = rel_path.split(os.sep)
+     if len(parts) < 2:
+         raise ValueError(f"Unexpected image path format: {image_path}")
+     class_name = parts[0]  # "pill"
+     return class_name
+ 
+ 
+ def main():
+     SEED = 41  # Any fixed integer works; it just has to stay constant across runs
+     set_seed(SEED)
+     # Configure logging
+     logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+ 
+     # Configuration
+     source_dir = "./data"
+     output_scores_dir = "./output_scores"
+     split = DatasetSplit.TEST  # Use the Enum instead of a raw string
+     device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ 
+     logging.info("Initializing the dataset and dataloader...")
+ 
+     # Initialize the dataset; the default output_size=224 matches the upsampled anomaly maps
+     dataset = AllClassesDataset(
+         source=source_dir,
+         split=split,
+         # output_size=17  # Alternatively, match the raw anomaly_map resolution
+     )
+     # The loop below reads each batch as a single sample, so batch_size must be 1
+     dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
+ 
+     logging.info("Initializing the anomaly detector...")
+     # Initialize anomaly detector
+     detector = AnomalyDetector(device=device)
+ 
+     # Initialize DumpScores
+     dump_scores = DumpScores(output_dir=output_scores_dir)
+ 
+     logging.info("Starting anomaly detection inference...")
+     # Initialize containers for metrics
+     classes = dataset.get_all_class_names()
+     metrics_data = {cls: {
+         "y_true_image": [],
+         "y_scores_image": [],
+         "y_true_pixel": [],
+         "y_scores_pixel": []
+     } for cls in classes}
+ 
+     # Iterate through the dataset
+     for batch_idx, batch in enumerate(dataloader):
+         image = batch['image'].squeeze(0)  # Shape: [3, H, W]
+         mask = batch['mask'].squeeze(1).numpy()  # Drop the channel dimension: [1, 224, 224]
+         image_label = batch['is_anomaly'].item()  # 1 or 0
+         image_path = batch['image_path'][0]  # batch_size is 1
+ 
+         # Extract class name from image_path
+         try:
+             class_name = get_class_name(image_path, source_dir)
+         except ValueError as e:
+             logging.error(f"Error extracting class name: {e}")
+             continue  # Skip this sample
+ 
+         # Extract features and compute scores using GLASS
+         image_score, anomaly_map = detector.extract_features(image, "all")
+ 
+         # Compute pixel-level anomaly score (already normalized)
+         pixel_score = detector.compute_pixel_score(anomaly_map).squeeze()
+ 
+         # Add batch and channel dimensions: [1, 1, 17, 17]
+         pixel_score_tensor = torch.from_numpy(pixel_score).float().unsqueeze(0).unsqueeze(0).to(device)
+ 
+         # Upsample pixel_score to (224, 224)
+         # Option 1: PyTorch interpolation
+         pixel_score = F.interpolate(
+             pixel_score_tensor,
+             size=(224, 224),
+             mode='bilinear',
+             align_corners=False
+         ).squeeze(0).cpu().numpy()  # Drop the batch dimension, keeping [1, 224, 224]
+ 
+         # Option 2: OpenCV (uncomment if preferred)
+         # pixel_score = cv2.resize(
+         #     pixel_score,
+         #     dsize=(224, 224),
+         #     interpolation=cv2.INTER_LINEAR
+         # )
+ 
+         # Optional: verify the upsampled pixel_score shape
+         # if pixel_score.shape != (1, 224, 224):
+         #     logging.warning(
+         #         f"Upsampled pixel score shape mismatch for image {image_path}: "
+         #         f"expected (1, 224, 224), got {pixel_score.shape}")
+         #     continue  # Skip this sample
+ 
+         # Append to metrics_data
+         metrics_data[class_name]["y_true_image"].append(image_label)
+         metrics_data[class_name]["y_scores_image"].append(image_score)
+         metrics_data[class_name]["y_true_pixel"].append(mask)
+         metrics_data[class_name]["y_scores_pixel"].append(pixel_score)
+ 
+         # Save individual image scores
+         dump_scores.save_scores([image_path], [image_score], [pixel_score])
+ 
+         logging.info(f"[{batch_idx + 1}/{len(dataloader)}] Processed image: {image_path}")
+         logging.info(f"Image-level score: {image_score:.4f}")
+         logging.info(f"Pixel-level mean score: {pixel_score.mean():.4f}")
+ 
+     logging.info("Anomaly detection inference completed. Computing metrics...")
+ 
+     # Dictionary to hold metrics per class
+     classes_metrics = {}
+ 
+     for cls in classes:
+         y_true_image = np.array(metrics_data[cls]["y_true_image"])
+         y_scores_image = np.array(metrics_data[cls]["y_scores_image"])
+         y_true_pixel = np.array(metrics_data[cls]["y_true_pixel"])
+         y_scores_pixel = np.array(metrics_data[cls]["y_scores_pixel"])
+ 
+         # Skip classes without any samples
+         if len(y_true_image) == 0:
+             logging.warning(f"No samples found for class {cls}. Skipping metric computation.")
+             continue
+ 
+         try:
+             metrics = compute_metrics(y_true_image, y_scores_image, y_true_pixel, y_scores_pixel)
+             classes_metrics[cls] = metrics
+             logging.info(f"Metrics computed for class: {cls}")
+         except Exception as e:
+             logging.error(f"Failed to compute metrics for class {cls}: {e}")
+ 
+     # Save metrics to JSON
+     os.makedirs(output_scores_dir, exist_ok=True)
+     metrics_json_path = os.path.join(output_scores_dir, "metrics.json")
+     try:
+         with open(metrics_json_path, "w") as f:
+             json.dump(classes_metrics, f, indent=4)
+         logging.info(f"Metrics successfully saved to {metrics_json_path}")
+     except Exception as e:
+         logging.error(f"Failed to save metrics to {metrics_json_path}: {e}")
+ 
+ 
+ if __name__ == "__main__":
+     main()
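The bilinear upsampling inside the loop is where the 17x17 patch grid becomes a 224x224 pixel map, so it is worth checking in isolation. A minimal standalone sketch with a random stand-in map (bilinear interpolation is a convex combination of neighbors, so values stay inside the input's range):

import torch
import torch.nn.functional as F

patch_map = torch.rand(1, 1, 17, 17)  # stand-in for a normalized anomaly map
pixel_map = F.interpolate(patch_map, size=(224, 224), mode='bilinear', align_corners=False)
print(pixel_map.shape)  # torch.Size([1, 1, 224, 224])
print(float(pixel_map.min()) >= 0.0, float(pixel_map.max()) <= 1.0)  # True True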
models/anomaly_detector.py ADDED
@@ -0,0 +1,186 @@
+ # models/anomaly_detector.py
+ 
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+ from .glass import GLASS  # Ensure correct import
+ import os
+ import logging
+ from torchvision import models
+ 
+ LOGGER = logging.getLogger(__name__)
+ 
+ 
+ class AnomalyDetector:
+     def __init__(self, device='cuda'):
+         self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
+ 
+         # Initialize the backbone (ResNet-50) without pretrained weights;
+         # weights=None is the non-deprecated spelling of pretrained=False
+         backbone = models.resnet50(weights=None)
+ 
+         # Load backbone weights from a local file
+         backbone_weights_path = './backbones/resnet50_backbone.pth'  # Update this path as needed
+         if os.path.exists(backbone_weights_path):
+             LOGGER.info(f"Loading ResNet-50 backbone weights from '{backbone_weights_path}'")
+             checkpoint = torch.load(backbone_weights_path, map_location="cpu")
+             try:
+                 backbone.load_state_dict(checkpoint, strict=True)
+                 LOGGER.info("ResNet-50 backbone weights loaded successfully.")
+             except RuntimeError as e:
+                 LOGGER.error(f"Error loading ResNet-50 backbone state_dict: {e}")
+                 raise
+         else:
+             LOGGER.error(f"Backbone weights not found at '{backbone_weights_path}'")
+             raise FileNotFoundError(f"Backbone weights not found at '{backbone_weights_path}'")
+ 
+         # Initialize the GLASS model
+         self.glass = GLASS(device=self.device)
+ 
+         # Define parameters for GLASS.load() to match training
+         layers_to_extract_from = ['layer4']  # Extract only the last layer
+         input_shape = (3, 224, 224)  # Match training input shape
+         pretrain_embed_dimension = 2048  # Channel dimension of 'layer4' in ResNet-50
+         target_embed_dimension = 1024  # Match training target dimension
+ 
+         # Initialize GLASS with parameters consistent with training
+         self.glass.load(
+             backbone=backbone,
+             layers_to_extract_from=layers_to_extract_from,
+             device=self.device,
+             input_shape=input_shape,
+             pretrain_embed_dimension=pretrain_embed_dimension,
+             target_embed_dimension=target_embed_dimension,
+             patchsize=3,
+             patchstride=1,
+             meta_epochs=640,  # Not relevant for inference but required by load()
+             eval_epochs=1,
+             dsc_layers=2,
+             dsc_hidden=1024,
+             dsc_margin=0.5,
+             train_backbone=False,
+             pre_proj=1,
+             mining=1,
+             noise=0.015,
+             radius=0.75,
+             p=0.5,
+             lr=0.0001,
+             svd=0,
+             step=20,
+             limit=392,
+         )
+ 
+         # Set model directories
+         model_dir = "./models"  # Base directory for models
+         dataset_name = "rayan_dataset"  # Example dataset name
+         self.glass.set_model_dir(model_dir, dataset_name)
+ 
+         self.glass.to(self.device)
+         self.glass.eval()  # Set GLASS to evaluation mode
+ 
+         # Cache to keep track of classes whose weights have been loaded
+         self.loaded_classes = set()
+ 
+     def load_model_weights(self, model_dir, classname):
+         """
+         Load the saved model weights for a specific class.
+ 
+         Args:
+             model_dir (str): Base directory where models are saved.
+             classname (str): The class name whose model weights to load.
+         """
+         checkpoint_path = os.path.join(model_dir, classname, f"best_model_{classname}.pth")
+         if os.path.exists(checkpoint_path):
+             LOGGER.info(f"Loading model weights from '{checkpoint_path}' for class '{classname}'")
+             checkpoint = torch.load(checkpoint_path, map_location=self.device)
+             try:
+                 self.glass.load_state_dict(checkpoint, strict=True)
+                 LOGGER.info(f"Model weights loaded successfully for class '{classname}'")
+             except RuntimeError as e:
+                 LOGGER.error(f"Error loading state_dict for class '{classname}': {e}")
+                 raise
+         else:
+             LOGGER.error(f"Checkpoint not found at '{checkpoint_path}' for class '{classname}'")
+             raise FileNotFoundError(f"Checkpoint not found at '{checkpoint_path}' for class '{classname}'")
+ 
+     def extract_features(self, image, classname):
+         """
+         Use GLASS to extract features and generate anomaly scores for a specific class.
+ 
+         Args:
+             image (torch.Tensor): Image tensor of shape [3, H, W].
+             classname (str): The class name for which to perform anomaly detection.
+ 
+         Returns:
+             tuple: (image_score, anomaly_map)
+         """
+         # Load model weights for classname if not already loaded
+         # if classname not in self.loaded_classes:
+         #     try:
+         #         self.load_model_weights(model_dir="./models", classname=classname)
+         #         self.loaded_classes.add(classname)
+         #     except FileNotFoundError as e:
+         #         LOGGER.error(f"Failed to load model weights for class '{classname}': {e}")
+         #         raise
+ 
+         # Reshape image to include batch dimension
+         image = image.unsqueeze(0).to(self.device)  # Shape: [1, 3, H, W]
+ 
+         # Use GLASS to get embeddings
+         with torch.no_grad():
+             patch_features, patch_shapes = self.glass._embed(image, evaluation=True)
+             if self.glass.pre_proj > 0:
+                 patch_features = self.glass.pre_projection(patch_features)
+                 # Handle if pre_projection returns multiple outputs
+                 if isinstance(patch_features, (tuple, list)):
+                     patch_features = patch_features[0]
+ 
+             # Pass through discriminator to get anomaly scores
+             patch_scores = self.glass.discriminator(patch_features)
+             patch_scores = self.glass.patch_maker.unpatch_scores(patch_scores, batchsize=image.shape[0])
+ 
+         # Select the last layer's patch_shapes (only one layer now)
+         last_patch_shape = patch_shapes[-1]  # Should be [17, 17] for 518x518 inputs
+ 
+         # Ensure that last_patch_shape is a list or tuple of two integers
+         if isinstance(last_patch_shape, (list, tuple)) and len(last_patch_shape) == 2:
+             # Reshape patch_scores to [batch_size, H_patches, W_patches]:
+             # first squeeze the trailing score dimension ...
+             patch_scores = patch_scores.squeeze(-1)  # Shape: [1, 289]
+             # ... then fold the flat patch axis back into a grid
+             patch_scores = patch_scores.reshape(image.shape[0], *last_patch_shape)  # [1, 17, 17]
+         else:
+             LOGGER.error(f"Unexpected patch_shapes format: {patch_shapes}")
+             raise ValueError(f"Unexpected patch_shapes format: {patch_shapes}")
+ 
+         # Compute image-level score (here: mean of patch scores)
+         image_score = patch_scores.mean().item()
+ 
+         # The anomaly map is the patch score grid itself, clipped to [0, 1]
+         anomaly_map = patch_scores.cpu().numpy()
+         anomaly_map = np.clip(anomaly_map, 0, 1)
+ 
+         # Log anomaly map statistics for debugging
+         LOGGER.info(f"Anomaly map stats for class '{classname}': min={anomaly_map.min():.4f}, max={anomaly_map.max():.4f}, mean={anomaly_map.mean():.4f}")
+ 
+         return image_score, anomaly_map
+ 
+     def compute_pixel_score(self, anomaly_map):
+         """
+         Process the anomaly map for pixel-level evaluation.
+ 
+         Args:
+             anomaly_map (np.ndarray): Anomaly map of shape [1, 17, 17].
+ 
+         Returns:
+             np.ndarray: Min-max normalized anomaly map of the same shape.
+         """
+         # Normalize anomaly_map to [0, 1]
+         min_val = anomaly_map.min()
+         max_val = anomaly_map.max()
+         if max_val - min_val < 1e-8:
+             LOGGER.warning("Anomaly map has zero variance. Returning zero map.")
+             return np.zeros_like(anomaly_map)
+         anomaly_map = (anomaly_map - min_val) / (max_val - min_val + 1e-8)
+         return anomaly_map
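The per-map min-max normalization in compute_pixel_score is easy to verify in isolation; a minimal sketch with a random stand-in map:

import numpy as np

amap = np.random.rand(1, 17, 17).astype(np.float32)  # stand-in anomaly map
lo, hi = amap.min(), amap.max()
if hi - lo < 1e-8:
    norm = np.zeros_like(amap)  # degenerate, constant map
else:
    norm = (amap - lo) / (hi - lo + 1e-8)
assert norm.min() >= 0.0 and norm.max() <= 1.0

One design caveat: normalizing each map against its own min and max discards the absolute score scale, so even a defect-free image ends up with pixel scores spanning [0, 1].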
models/common.py ADDED
@@ -0,0 +1,154 @@
+ # common.py
+ 
+ import copy
+ import numpy as np
+ import scipy.ndimage as ndimage
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ 
+ 
+ class Preprocessing(torch.nn.Module):
+     def __init__(self, input_dims, output_dim):
+         super(Preprocessing, self).__init__()
+         self.input_dims = input_dims
+         self.output_dim = output_dim
+ 
+         self.preprocessing_modules = torch.nn.ModuleList()
+         for _ in input_dims:
+             module = MeanMapper(output_dim)
+             self.preprocessing_modules.append(module)
+ 
+     def forward(self, features):
+         _features = []
+         for module, feature in zip(self.preprocessing_modules, features):
+             _features.append(module(feature))
+         return torch.stack(_features, dim=1)
+ 
+ 
+ class MeanMapper(torch.nn.Module):
+     def __init__(self, preprocessing_dim):
+         super(MeanMapper, self).__init__()
+         self.preprocessing_dim = preprocessing_dim
+ 
+     def forward(self, features):
+         features = features.reshape(len(features), 1, -1)
+         return F.adaptive_avg_pool1d(features, self.preprocessing_dim).squeeze(1)
+ 
+ 
+ class Aggregator(torch.nn.Module):
+     def __init__(self, target_dim):
+         super(Aggregator, self).__init__()
+         self.target_dim = target_dim
+ 
+     def forward(self, features):
+         """Returns reshaped and average pooled features."""
+         features = features.reshape(len(features), 1, -1)
+         features = F.adaptive_avg_pool1d(features, self.target_dim)
+         return features.reshape(len(features), -1)
+ 
+ 
+ class RescaleSegmentor:
+     def __init__(self, device, target_size=288):
+         self.device = device
+         self.target_size = target_size
+         self.smoothing = 4
+ 
+     def convert_to_segmentation(self, patch_scores):
+         with torch.no_grad():
+             if isinstance(patch_scores, np.ndarray):
+                 patch_scores = torch.from_numpy(patch_scores)
+             _scores = patch_scores.to(self.device)
+             _scores = _scores.unsqueeze(1)
+             _scores = F.interpolate(
+                 _scores, size=self.target_size, mode="bilinear", align_corners=False
+             )
+             _scores = _scores.squeeze(1)
+             patch_scores = _scores.cpu().numpy()
+         return [ndimage.gaussian_filter(patch_score, sigma=self.smoothing) for patch_score in patch_scores]
+ 
+ 
+ class NetworkFeatureAggregator(torch.nn.Module):
+     """Efficient extraction of network features.
+ 
+     Runs a network only up to the last layer in the list of layers from
+     which features should be extracted.
+ 
+     Args:
+         backbone: torchvision.model
+         layers_to_extract_from: [list of str]
+     """
+ 
+     def __init__(self, backbone, layers_to_extract_from, device, train_backbone=False):
+         super(NetworkFeatureAggregator, self).__init__()
+         self.layers_to_extract_from = layers_to_extract_from
+         self.backbone = backbone
+         self.device = device
+         self.train_backbone = train_backbone
+         if not hasattr(backbone, "hook_handles"):
+             self.backbone.hook_handles = []
+         for handle in self.backbone.hook_handles:
+             handle.remove()
+         self.outputs = {}
+ 
+         for extract_layer in layers_to_extract_from:
+             self.register_hook(extract_layer)
+ 
+         self.to(self.device)
+ 
+     def forward(self, images, eval=True):
+         self.outputs.clear()
+         if self.train_backbone and not eval:
+             self.backbone.train()
+             self.backbone(images)
+         else:
+             self.backbone.eval()
+             with torch.no_grad():
+                 self.backbone(images)
+         return self.outputs
+ 
+     def feature_dimensions(self, input_shape):
+         """Computes the feature dimensions for all layers given input_shape."""
+         _input = torch.ones([1] + list(input_shape)).to(self.device)
+         _output = self(_input)
+         return [_output[layer].shape[1] for layer in self.layers_to_extract_from]
+ 
+     def register_hook(self, layer_name):
+         module = self.find_module(self.backbone, layer_name)
+         if module is not None:
+             forward_hook = ForwardHook(self.outputs, layer_name, self.layers_to_extract_from[-1])
+             if isinstance(module, torch.nn.Sequential):
+                 hook = module[-1].register_forward_hook(forward_hook)
+             else:
+                 hook = module.register_forward_hook(forward_hook)
+             self.backbone.hook_handles.append(hook)
+         else:
+             raise ValueError(f"Module {layer_name} not found in the model")
+ 
+     def find_module(self, model, module_name):
+         for name, module in model.named_modules():
+             if name == module_name:
+                 return module
+             elif '.' in module_name:
+                 father, child = module_name.split('.', 1)
+                 if name == father:
+                     return self.find_module(module, child)
+         return None
+ 
+ 
+ class ForwardHook:
+     def __init__(self, hook_dict, layer_name: str, last_layer_to_extract: str):
+         self.hook_dict = hook_dict
+         self.layer_name = layer_name
+         self.raise_exception_to_break = copy.deepcopy(
+             layer_name == last_layer_to_extract
+         )
+ 
+     def __call__(self, module, input, output):
+         self.hook_dict[self.layer_name] = output
+         return None
+ 
+ 
+ class LastLayerToExtractReachedException(Exception):
+     pass
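NetworkFeatureAggregator's hook machinery reduces to the stock register_forward_hook API. A self-contained sketch of the same idea on a bare ResNet-50 (random weights, so the values are meaningless but the shapes are representative):

import torch
from torchvision import models

outputs = {}
net = models.resnet50(weights=None).eval()
handle = net.layer4.register_forward_hook(
    lambda module, inputs, output: outputs.__setitem__('layer4', output)
)
with torch.no_grad():
    net(torch.randn(1, 3, 224, 224))
print(outputs['layer4'].shape)  # torch.Size([1, 2048, 7, 7])
handle.remove()  # mirror the handle bookkeeping done in __init__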
models/glass.py ADDED
@@ -0,0 +1,372 @@
+ # models/glass.py
+ 
+ import logging
+ import math
+ import os
+ import torch
+ import numpy as np
+ import torch.nn as nn
+ from torch.cuda.amp import GradScaler, autocast
+ from .common import NetworkFeatureAggregator, Preprocessing, Aggregator, RescaleSegmentor
+ import torch.nn.functional as F
+ from torch.utils.tensorboard import SummaryWriter
+ import torch.optim as optim
+ from .model import Discriminator, Projection, PatchMaker
+ 
+ LOGGER = logging.getLogger(__name__)
+ IMAGENET_MEAN = [0.485, 0.456, 0.406]
+ IMAGENET_STD = [0.229, 0.224, 0.225]
+ 
+ 
+ class TBWrapper:
+     def __init__(self, log_dir):
+         self.g_iter = 0
+         self.logger = SummaryWriter(log_dir=log_dir)
+ 
+     def step(self):
+         self.g_iter += 1
+ 
+     def log(self, tag, value, step):
+         self.logger.add_scalar(tag, value, step)
+ 
+     def close(self):
+         # trainer() calls close() on this wrapper, so forward it to the writer
+         self.logger.close()
+ 
+ 
+ class GLASS(torch.nn.Module):
+     def __init__(self, device):
+         super(GLASS, self).__init__()
+         self.device = device
+ 
+     def load(
+         self,
+         backbone,
+         layers_to_extract_from,
+         device,
+         input_shape,
+         pretrain_embed_dimension,
+         target_embed_dimension,
+         patchsize=3,
+         patchstride=1,
+         meta_epochs=640,
+         eval_epochs=1,
+         dsc_layers=2,
+         dsc_hidden=1024,
+         dsc_margin=0.5,
+         train_backbone=False,  # Set externally by the caller
+         pre_proj=1,
+         mining=1,
+         noise=0.015,
+         radius=0.75,
+         p=0.5,
+         lr=0.0001,
+         svd=0,
+         step=20,
+         limit=392,
+         **kwargs,
+     ):
+         self.backbone = backbone.to(device)
+         self.layers_to_extract_from = layers_to_extract_from
+         self.input_shape = input_shape
+         self.device = device
+ 
+         self.forward_modules = torch.nn.ModuleDict({})
+         feature_aggregator = NetworkFeatureAggregator(
+             self.backbone, self.layers_to_extract_from, self.device, train_backbone
+         )
+         feature_dimensions = feature_aggregator.feature_dimensions(input_shape)
+         self.forward_modules["feature_aggregator"] = feature_aggregator
+ 
+         preprocessing = Preprocessing(feature_dimensions, pretrain_embed_dimension)
+         self.forward_modules["preprocessing"] = preprocessing
+         self.target_embed_dimension = target_embed_dimension
+         preadapt_aggregator = Aggregator(target_dim=target_embed_dimension)
+         preadapt_aggregator.to(self.device)
+         self.forward_modules["preadapt_aggregator"] = preadapt_aggregator
+ 
+         self.meta_epochs = meta_epochs
+         self.lr = lr
+         self.train_backbone = train_backbone
+         if self.train_backbone:
+             self.backbone_opt = torch.optim.AdamW(self.forward_modules["feature_aggregator"].backbone.parameters(), lr)
+ 
+         self.pre_proj = pre_proj
+         if self.pre_proj > 0:
+             self.pre_projection = Projection(self.target_embed_dimension, self.target_embed_dimension, pre_proj)
+             self.pre_projection.to(self.device)
+             self.proj_opt = torch.optim.Adam(self.pre_projection.parameters(), lr, weight_decay=1e-5)
+ 
+         self.eval_epochs = eval_epochs
+         self.dsc_layers = dsc_layers
+         self.dsc_hidden = dsc_hidden
+         self.discriminator = Discriminator(self.target_embed_dimension, n_layers=dsc_layers, hidden=dsc_hidden)
+         self.discriminator.to(self.device)
+         self.dsc_opt = torch.optim.AdamW(self.discriminator.parameters(), lr=lr * 2)
+         self.dsc_margin = dsc_margin
+ 
+         self.c = torch.tensor(0)
+         self.c_ = torch.tensor(0)
+         self.p = p
+         self.radius = radius
+         self.mining = mining
+         self.noise = noise
+         self.svd = svd
+         self.step = step
+         self.limit = limit
+         self.distribution = 0
+ 
+         # MSELoss in place of the original FocalLoss
+         self.loss_fn = nn.MSELoss()
+ 
+         self.patch_maker = PatchMaker(patchsize, stride=patchstride)
+         self.anomaly_segmentor = RescaleSegmentor(device=self.device, target_size=input_shape[-2:])
+         self.model_dir = ""
+         self.dataset_name = ""
+         self.logger = None
+ 
+     def set_model_dir(self, model_dir, dataset_name):
+         self.model_dir = model_dir
+         os.makedirs(self.model_dir, exist_ok=True)
+         self.ckpt_dir = os.path.join(self.model_dir, dataset_name)
+         os.makedirs(self.ckpt_dir, exist_ok=True)
+         self.tb_dir = os.path.join(self.ckpt_dir, "tb")
+         os.makedirs(self.tb_dir, exist_ok=True)
+         self.logger = TBWrapper(self.tb_dir)
+ 
+     def _embed(self, images, detach=True, provide_patch_shapes=False, evaluation=False):
+         """Returns feature embeddings for images."""
+         images = images.float()  # Ensure input tensor is float32
+         if not evaluation and self.train_backbone:
+             self.forward_modules["feature_aggregator"].train()
+             features = self.forward_modules["feature_aggregator"](images, eval=evaluation)
+         else:
+             self.forward_modules["feature_aggregator"].eval()
+             with torch.no_grad():
+                 features = self.forward_modules["feature_aggregator"](images)
+ 
+         features = [features[layer] for layer in self.layers_to_extract_from]
+ 
+         for i, feat in enumerate(features):
+             if len(feat.shape) == 3:
+                 # Transformer-style output [B, L, C]: fold back into a square grid
+                 B, L, C = feat.shape
+                 sqrt_L = int(math.sqrt(L))
+                 if sqrt_L * sqrt_L != L:
+                     raise ValueError(f"Layer {self.layers_to_extract_from[i]} output has non-square spatial dimensions: {feat.shape}")
+                 features[i] = feat.reshape(B, sqrt_L, sqrt_L, C).permute(0, 3, 1, 2)
+             # Debug statement; gradients are only expected while training the backbone
+             if not evaluation and self.train_backbone:
+                 assert features[i].requires_grad, f"Feature {i} from layer {self.layers_to_extract_from[i]} does not require grad."
+ 
+         features = [self.patch_maker.patchify(x, return_spatial_info=True) for x in features]
+         patch_shapes = [x[1] for x in features]
+         patch_features = [x[0] for x in features]
+         ref_num_patches = patch_shapes[0]
+ 
+         for i in range(1, len(patch_features)):
+             _features = patch_features[i]
+             patch_dims = patch_shapes[i]
+ 
+             _features = _features.reshape(
+                 _features.shape[0], patch_dims[0], patch_dims[1], *_features.shape[2:]
+             )
+             _features = _features.permute(0, 3, 4, 5, 1, 2)
+             perm_base_shape = _features.shape
+             _features = _features.reshape(-1, *_features.shape[-2:])
+             _features = F.interpolate(
+                 _features.unsqueeze(1),
+                 size=(ref_num_patches[0], ref_num_patches[1]),
+                 mode="bilinear",
+                 align_corners=False,
+             )
+             _features = _features.squeeze(1)
+             _features = _features.reshape(
+                 *perm_base_shape[:-2], ref_num_patches[0], ref_num_patches[1]
+             )
+             _features = _features.permute(0, 4, 5, 1, 2, 3)
+             _features = _features.reshape(len(_features), -1, *_features.shape[-3:])
+             patch_features[i] = _features
+ 
+         patch_features = [x.reshape(-1, *x.shape[-3:]) for x in patch_features]
+         patch_features = self.forward_modules["preprocessing"](patch_features)
+         patch_features = self.forward_modules["preadapt_aggregator"](patch_features)
+ 
+         return patch_features, patch_shapes
+ 
+     def trainer(self, training_data, val_data, name):
+         """
+         Training loop for the GLASS model.
+ 
+         Args:
+             training_data (DataLoader): DataLoader for the training dataset.
+             val_data (DataLoader): DataLoader for the validation dataset.
+             name (str): Name identifier for the training run.
+         """
+         self.train()
+         self.discriminator.train()
+ 
+         # Initialize optimizers
+         optimizer = optim.AdamW(self.forward_modules.parameters(), lr=self.lr)
+         optimizer_d = optim.AdamW(self.discriminator.parameters(), lr=self.lr * 2)
+ 
+         # The Discriminator ends in a Sigmoid, so its outputs are probabilities;
+         # plain BCELoss is the matching criterion (BCEWithLogitsLoss would apply
+         # a second sigmoid).
+         criterion_d = nn.BCELoss()
+ 
+         # Initialize separate AMP scalers
+         scaler_main = GradScaler()
+         scaler_dsc = GradScaler()
+ 
+         # Use the TensorBoard wrapper set up in set_model_dir, or a default one
+         if self.logger is not None:
+             tb_writer = self.logger
+         else:
+             tb_writer = TBWrapper(log_dir=None)  # SummaryWriter's default run directory
+ 
+         best_auroc = 0.0
+         best_model_path = os.path.join(self.model_dir, f"best_model_{name}.pth")
+ 
+         for epoch in range(1, self.meta_epochs + 1):
+             LOGGER.info(f"Epoch [{epoch}/{self.meta_epochs}]")
+             epoch_loss = 0.0
+             epoch_loss_d = 0.0
+             for batch_idx, batch in enumerate(training_data):
+                 images = batch['image'].to(self.device).float()      # [B, 3, H, W]
+                 aug_images = batch['aug'].to(self.device).float()    # [B, 3, H, W]
+                 masks_s = batch['mask_s'].to(self.device).float()    # [B, H, W]
+                 masks_gt = batch['mask_gt'].to(self.device).float()  # [B, 1, H, W]
+ 
+                 optimizer.zero_grad()
+                 optimizer_d.zero_grad()
+ 
+                 # ----- Train Main Model -----
+                 with autocast():
+                     # Forward pass
+                     embeddings, _ = self._embed(images)          # [B*N_patches, D]
+                     aug_embeddings, _ = self._embed(aug_images)  # [B*N_patches, D]
+ 
+                     # Aggregate embeddings to [B, D] by averaging over patches
+                     B = images.size(0)
+                     N_patches = embeddings.size(0) // B
+                     assert embeddings.size(0) == B * N_patches, "Embeddings cannot be evenly divided into the batch size."
+                     embeddings = embeddings.view(B, N_patches, -1).mean(dim=1)          # [B, D]
+                     aug_embeddings = aug_embeddings.view(B, N_patches, -1).mean(dim=1)  # [B, D]
+ 
+                     # Debug tensor properties
+                     assert embeddings.requires_grad, "Embeddings do not require grad!"
+                     assert aug_embeddings.requires_grad, "Augmented embeddings do not require grad!"
+                     assert embeddings.shape[0] == images.size(0), "Aggregated embeddings batch size does not match input batch size."
+ 
+                     # Compute reconstruction or similarity loss
+                     loss = self.loss_fn(embeddings, aug_embeddings)
+ 
+                 # Backward pass with AMP for the main model (outside autocast)
+                 scaler_main.scale(loss).backward()
+                 scaler_main.step(optimizer)
+                 scaler_main.update()
+ 
+                 epoch_loss += loss.item()
+ 
+                 # ----- Train Discriminator -----
+                 with autocast():
+                     # Detach embeddings to prevent gradients flowing back to the main model
+                     embeddings_detached = embeddings.detach()
+                     aug_embeddings_detached = aug_embeddings.detach()
+ 
+                     # Discriminator forward pass
+                     outputs_real = self.discriminator(embeddings_detached)      # [B, 1]
+                     outputs_fake = self.discriminator(aug_embeddings_detached)  # [B, 1]
+ 
+                 # Create labels
+                 real_labels = torch.ones(outputs_real.size(0), 1).to(self.device)   # [B, 1]
+                 fake_labels = torch.zeros(outputs_fake.size(0), 1).to(self.device)  # [B, 1]
+ 
+                 # BCELoss is unsafe under autocast, so compute it in full precision
+                 loss_real = criterion_d(outputs_real.float(), real_labels)
+                 loss_fake = criterion_d(outputs_fake.float(), fake_labels)
+                 loss_d = loss_real + loss_fake
+ 
+                 # Backward pass with AMP for the discriminator
+                 scaler_dsc.scale(loss_d).backward()
+                 scaler_dsc.step(optimizer_d)
+                 scaler_dsc.update()
+ 
+                 epoch_loss_d += loss_d.item()
+ 
+                 if batch_idx % 100 == 0:
+                     LOGGER.info(f"Batch [{batch_idx}/{len(training_data)}] "
+                                 f"Loss: {loss.item():.4f} Loss_D: {loss_d.item():.4f}")
+ 
+             avg_epoch_loss = epoch_loss / len(training_data)
+             avg_epoch_loss_d = epoch_loss_d / len(training_data)
+ 
+             LOGGER.info(f"Epoch [{epoch}/{self.meta_epochs}] "
+                         f"Average Loss: {avg_epoch_loss:.4f} "
+                         f"Average Loss_D: {avg_epoch_loss_d:.4f}")
+ 
+             # Log to TensorBoard
+             tb_writer.log("Train/Loss", avg_epoch_loss, epoch)
+             tb_writer.log("Train/Loss_D", avg_epoch_loss_d, epoch)
+ 
+             # Validation
+             if epoch % self.eval_epochs == 0:
+                 auroc = self.tester(val_data, name)
+                 LOGGER.info(f"Validation AUROC after Epoch [{epoch}]: {auroc:.4f}")
+                 tb_writer.log("Validation/AUROC", auroc, epoch)
+ 
+                 # Save the best model
+                 if auroc > best_auroc:
+                     best_auroc = auroc
+                     torch.save(self.state_dict(), best_model_path)  # Save only the state_dict
+                     LOGGER.info(f"Best model saved at Epoch [{epoch}] with AUROC: {auroc:.4f}")
+ 
+         LOGGER.info(f"Training completed. Best AUROC: {best_auroc:.4f}")
+         tb_writer.close()
+ 
+     def tester(self, test_data, name):
+         """
+         Evaluation loop for the GLASS model.
+ 
+         Args:
+             test_data (DataLoader): DataLoader for the test dataset.
+             name (str): Name identifier for the evaluation run.
+ 
+         Returns:
+             float: AUROC score on the test dataset.
+         """
+         self.eval()
+         self.discriminator.eval()
+         all_scores = []
+         all_labels = []
+ 
+         with torch.no_grad():
+             for batch_idx, batch in enumerate(test_data):
+                 images = batch['image'].to(self.device).float()      # [B, 3, H, W]
+                 masks_gt = batch['mask_gt'].to(self.device).float()  # [B, 1, H, W]
+                 labels = batch['is_anomaly'].cpu().numpy()           # [B]
+ 
+                 # Forward pass
+                 embeddings, _ = self._embed(images, evaluation=True)  # [B*N_patches, D]
+                 B = images.size(0)
+                 N_patches = embeddings.size(0) // B
+                 embeddings = embeddings.view(B, N_patches, -1).mean(dim=1)  # [B, D]
+                 anomaly_scores = self.discriminator(embeddings).cpu().numpy().flatten()  # [B]
+ 
+                 all_scores.extend(anomaly_scores.tolist())
+                 all_labels.extend(labels.tolist())
+ 
+         # Compute AUROC
+         from sklearn.metrics import roc_auc_score
+         auroc = roc_auc_score(all_labels, all_scores)
+         return auroc
+ 
+     def _evaluate(self, images, scores, segmentations, labels_gt, masks_gt, name, path='training'):
+         # Implementation of evaluation metrics
+         pass
+ 
+     def predict(self, test_dataloader):
+         """This function provides anomaly scores/maps for full dataloaders."""
+         # Implementation of prediction logic
+         pass
+ 
+     def _predict(self, img):
+         """Infer score and mask for a batch of images."""
+         # Implementation of individual prediction logic
+         pass
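Both trainer and tester collapse the per-patch embeddings returned by _embed into one vector per image with the same view/mean pattern. A standalone shape check with random tensors (289 patches corresponds to the 17x17 grid produced by 518x518 inputs):

import torch

B, n_patches, dim = 2, 289, 1024         # 289 = 17 x 17 patches per image
flat = torch.randn(B * n_patches, dim)   # _embed output: one row per patch
per_image = flat.view(B, n_patches, -1).mean(dim=1)
print(per_image.shape)                   # torch.Size([2, 1024])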
models/model.py ADDED
@@ -0,0 +1,101 @@
+ # models/model.py
+ 
+ import torch
+ 
+ 
+ def init_weight(m):
+     if isinstance(m, torch.nn.Linear):
+         torch.nn.init.xavier_normal_(m.weight)
+     if isinstance(m, torch.nn.BatchNorm2d):
+         m.weight.data.normal_(1.0, 0.02)
+         m.bias.data.fill_(0)
+     elif isinstance(m, torch.nn.Conv2d):
+         m.weight.data.normal_(0.0, 0.02)
+ 
+ 
+ class Discriminator(torch.nn.Module):
+     def __init__(self, in_planes, n_layers=2, hidden=None):
+         super(Discriminator, self).__init__()
+ 
+         _hidden = in_planes if hidden is None else hidden
+         self.body = torch.nn.Sequential()
+         for i in range(n_layers - 1):
+             _in = in_planes if i == 0 else _hidden
+             _hidden = int(_hidden // 1.5) if hidden is None else hidden
+             self.body.add_module('block%d' % (i + 1),
+                                  torch.nn.Sequential(
+                                      torch.nn.Linear(_in, _hidden),
+                                      torch.nn.BatchNorm1d(_hidden),
+                                      torch.nn.LeakyReLU(0.2)
+                                  ))
+         self.tail = torch.nn.Sequential(
+             torch.nn.Linear(_hidden, 1, bias=False),
+             torch.nn.Sigmoid()  # outputs are probabilities, not logits
+         )
+         self.apply(init_weight)
+ 
+     def forward(self, x):
+         x = self.body(x)
+         x = self.tail(x)
+         return x
+ 
+ 
+ class Projection(torch.nn.Module):
+     def __init__(self, in_planes, out_planes=None, n_layers=1, layer_type=0):
+         super(Projection, self).__init__()
+ 
+         if out_planes is None:
+             out_planes = in_planes
+         self.layers = torch.nn.Sequential()
+         _in = None
+         _out = None
+         for i in range(n_layers):
+             _in = in_planes if i == 0 else _out
+             _out = out_planes
+             self.layers.add_module(f"{i}fc", torch.nn.Linear(_in, _out))
+             if i < n_layers - 1:
+                 if layer_type > 1:
+                     self.layers.add_module(f"{i}relu", torch.nn.LeakyReLU(.2))
+         self.apply(init_weight)
+ 
+     def forward(self, x):
+         x = self.layers(x)
+         return x
+ 
+ 
+ class PatchMaker:
+     def __init__(self, patchsize, top_k=0, stride=None):
+         self.patchsize = patchsize
+         self.stride = stride
+         self.top_k = top_k
+ 
+     def patchify(self, features, return_spatial_info=False):
+         """Convert a tensor into a tensor of respective patches.
+ 
+         Args:
+             features: [torch.Tensor, bs x c x w x h]
+         Returns:
+             patches: [torch.Tensor, bs * w//stride * h//stride, c, patchsize, patchsize]
+         """
+         padding = int((self.patchsize - 1) / 2)
+         unfolder = torch.nn.Unfold(kernel_size=self.patchsize, stride=self.stride, padding=padding, dilation=1)
+         unfolded_features = unfolder(features)
+         number_of_total_patches = []
+         for s in features.shape[-2:]:
+             n_patches = (s + 2 * padding - 1 * (self.patchsize - 1) - 1) / self.stride + 1
+             number_of_total_patches.append(int(n_patches))
+         unfolded_features = unfolded_features.reshape(
+             *features.shape[:2], self.patchsize, self.patchsize, -1
+         )
+         unfolded_features = unfolded_features.permute(0, 4, 1, 2, 3)
+ 
+         if return_spatial_info:
+             return unfolded_features, number_of_total_patches
+         return unfolded_features
+ 
+     def unpatch_scores(self, x, batchsize):
+         return x.reshape(batchsize, -1, *x.shape[1:])
+ 
+     def score(self, x):
+         x = x[:, :, 0]
+         x = torch.max(x, dim=1).values
+         return x
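PatchMaker.patchify is a thin wrapper around torch.nn.Unfold, and the loop over features.shape[-2:] applies the standard convolution output-size formula. A standalone check with the values used at inference time (patchsize=3, stride=1, a 17x17 feature grid):

import torch

patchsize, stride = 3, 1
padding = (patchsize - 1) // 2
features = torch.randn(1, 2048, 17, 17)  # e.g. a layer4 map for a 518x518 input
unfolder = torch.nn.Unfold(kernel_size=patchsize, stride=stride, padding=padding)
unfolded = unfolder(features)            # [1, 2048 * 3 * 3, n_patches]
n = (17 + 2 * padding - (patchsize - 1) - 1) // stride + 1
print(unfolded.shape, n * n)             # torch.Size([1, 18432, 289]) 289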
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch==2.4.1
+ torchvision==0.19.1  # torchvision release paired with torch 2.4.1
+ numpy==1.23.5
+ Pillow==9.4.0
+ tqdm==4.65.0
+ scikit-image==0.20.0
+ scikit-learn==1.2.2
+ scipy==1.11.4
run.sh ADDED
@@ -0,0 +1,2 @@
+ #!/bin/sh
+ python3 main.py
runner.py ADDED
@@ -0,0 +1,37 @@
+ # -----------------------------------------------------------------------------
+ # This Python script is the primary entry point called by our judge. It runs
+ # your code to generate anomaly scores, then evaluates those scores to produce
+ # the final results.
+ # -----------------------------------------------------------------------------
+ 
+ import subprocess
+ 
+ # Step 1: Generate anomaly scores
+ subprocess.run(["./run.sh"], check=True)
+ 
+ # Step 2: Evaluate the generated scores
+ subprocess.run(
+     [
+         "python3",
+         "evaluation/eval_main.py",
+         "--device",
+         "0",
+         "--data_path",
+         "./data/",
+         "--dataset_name",
+         "rayan_dataset",
+         "--class_name",
+         "all",
+         "--output_dir",
+         "./output",
+         "--output_scores_dir",
+         "./output_scores",
+         "--save_csv",
+         "True",
+         "--save_json",
+         "True",
+         "--class_name_mapping_dir",
+         "./evaluation/class_name_mapping.json",
+     ],
+     check=True,
+ )
utils/dump_scores.py ADDED
@@ -0,0 +1,34 @@
+ # utils/dump_scores.py
+ 
+ import os
+ import json
+ from pathlib import Path
+ 
+ 
+ class DumpScores:
+     def __init__(self, output_dir):
+         self.output_dir = output_dir
+         os.makedirs(self.output_dir, exist_ok=True)
+ 
+     def save_scores(self, image_paths, img_level_scores, pix_level_scores):
+         for img_path, img_score, pix_score in zip(image_paths, img_level_scores, pix_level_scores):
+             # Determine the relative path to maintain directory structure
+             relative_path = os.path.relpath(img_path, "./data")
+             relative_dir = os.path.dirname(relative_path)
+             output_dir = os.path.join(self.output_dir, relative_dir)
+             os.makedirs(output_dir, exist_ok=True)
+ 
+             # Get the image filename without extension
+             img_name = Path(img_path).stem
+ 
+             # Create the JSON structure
+             score_data = {
+                 "img_level_score": img_score,
+                 "pix_level_score": pix_score.tolist()  # Convert numpy array to list for JSON serialization
+             }
+ 
+             # Define the output JSON file path
+             json_path = os.path.join(output_dir, f"{img_name}_scores.json")
+ 
+             # Save the JSON file
+             with open(json_path, "w") as f:
+                 json.dump(score_data, f, indent=4)
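A minimal usage sketch for DumpScores; the image path below is hypothetical (only its relative layout matters, the file itself is never opened), and the import assumes the utils package is on the Python path:

import numpy as np
from utils.dump_scores import DumpScores

dump = DumpScores(output_dir="./output_scores")
dump.save_scores(
    image_paths=["./data/pill/test/broken/000.png"],  # hypothetical sample
    img_level_scores=[0.87],
    pix_level_scores=[np.zeros((1, 224, 224), dtype=np.float32)],
)
# Writes ./output_scores/pill/test/broken/000_scores.json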
utils/feature_extractor.py ADDED
@@ -0,0 +1,18 @@
+ # utils/feature_extractor.py
+ 
+ import torch
+ import torch.nn as nn
+ from torchvision import models
+ 
+ 
+ class FeatureExtractor(nn.Module):
+     def __init__(self, backbone='resnet50'):
+         super(FeatureExtractor, self).__init__()
+         if backbone == 'resnet50':
+             # weights= replaces the deprecated pretrained=True argument
+             self.model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
+             # Drop the final average-pool and fully connected layers
+             self.features = nn.Sequential(*list(self.model.children())[:-2])
+         else:
+             raise NotImplementedError(f"Backbone {backbone} is not implemented.")
+ 
+     def forward(self, x):
+         return self.features(x)
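And a quick shape check for the class above (the first call downloads the ImageNet weights; keeping everything up to layer4 leaves a 2048-channel feature map):

import torch

extractor = FeatureExtractor().eval()
with torch.no_grad():
    feats = extractor(torch.randn(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 2048, 7, 7])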