Files changed:
- Dockerfile +14 -0
- datasets/all_classes_dataset.py +158 -0
- datasets/mvec.py +301 -0
- datasets/perlin.py +73 -0
- datasets/rayan_dataset.py +127 -0
- docker-compose.yml +21 -0
- evaluation/base_eval.py +293 -0
- evaluation/class_name_mapping.json +5 -0
- evaluation/eval_main.py +78 -0
- evaluation/json_score.py +98 -0
- evaluation/utils/json_helpers.py +46 -0
- evaluation/utils/metrics.py +111 -0
- main.py +308 -0
- models/anomaly_detector.py +186 -0
- models/common.py +154 -0
- models/glass.py +372 -0
- models/model.py +101 -0
- requirements.txt +8 -0
- run.sh +2 -0
- runner.py +37 -0
- utils/dump_scores.py +34 -0
- utils/feature_extractor.py +18 -0
Dockerfile
ADDED
@@ -0,0 +1,14 @@
+# -----------------------------------------------------------------------------
+# A sample Dockerfile to help you replicate our test environment
+# -----------------------------------------------------------------------------
+
+FROM pytorch/pytorch:2.4.1-cuda12.4-cudnn9-runtime
+WORKDIR /app
+COPY . .
+
+# Install your python and apt requirements
+RUN pip install -r requirements.txt
+RUN apt-get update && apt-get install $(cat apt_requirements.txt) -y
+RUN chmod +x run.sh
+
+CMD ["python3", "runner.py"]
datasets/all_classes_dataset.py
ADDED
@@ -0,0 +1,158 @@
+# datasets/all_classes_dataset.py
+
+import os
+from enum import Enum
+
+import PIL
+import torch
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+IMAGENET_MEAN = [0.485, 0.456, 0.406]
+IMAGENET_STD = [0.229, 0.224, 0.225]
+
+
+class DatasetSplit(Enum):
+    TRAIN = "train"
+    VAL = "val"
+    TEST = "test"
+
+
+class AllClassesDataset(Dataset):
+    def __init__(
+        self,
+        source,
+        input_size=518,
+        output_size=224,
+        split=DatasetSplit.TEST,
+        external_transform=None,
+        **kwargs,
+    ):
+        """
+        Initialize the dataset to include all classes.
+
+        Args:
+            source (str): Path to the root data directory.
+            input_size (int): Input image size for transformations.
+            output_size (int): Output mask size.
+            split (DatasetSplit): Dataset split to use (TRAIN, VAL, TEST).
+            external_transform (callable, optional): External image transformations.
+            **kwargs: Additional keyword arguments.
+        """
+        super().__init__()
+        self.source = source
+        self.split = split
+        self.classnames_to_use = self.get_all_class_names()
+
+        self.imgpaths_per_class, self.data_to_iterate = self.get_image_data()
+
+        if external_transform is None:
+            self.transform_img = transforms.Compose([
+                transforms.Resize((input_size, input_size)),
+                # transforms.CenterCrop(input_size),
+                transforms.ToTensor(),
+                transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
+            ])
+        else:
+            self.transform_img = external_transform
+
+        self.transform_mask = transforms.Compose([
+            transforms.Resize((output_size, output_size)),
+            # transforms.CenterCrop(output_size),
+            transforms.ToTensor(),
+        ])
+        self.output_shape = (1, output_size, output_size)
+
+    def get_all_class_names(self):
+        """
+        Retrieve all class names (subdirectories) from the source directory.
+
+        Returns:
+            list: List of class names.
+        """
+        all_items = os.listdir(self.source)
+        classnames = [
+            item for item in all_items
+            if os.path.isdir(os.path.join(self.source, item))
+        ]
+        return classnames
+
+    def get_image_data(self):
+        """
+        Collect image paths and corresponding mask paths for all classes.
+
+        Returns:
+            tuple: (imgpaths_per_class, data_to_iterate)
+        """
+        imgpaths_per_class = {}
+        maskpaths_per_class = {}
+
+        for classname in self.classnames_to_use:
+            classpath = os.path.join(self.source, classname, self.split.value)
+            maskpath = os.path.join(self.source, classname, "ground_truth")
+            anomaly_types = os.listdir(classpath)
+
+            imgpaths_per_class[classname] = {}
+            maskpaths_per_class[classname] = {}
+
+            for anomaly in anomaly_types:
+                anomaly_path = os.path.join(classpath, anomaly)
+                anomaly_files = sorted(os.listdir(anomaly_path))
+                imgpaths_per_class[classname][anomaly] = [
+                    os.path.join(anomaly_path, x) for x in anomaly_files
+                ]
+
+                if self.split == DatasetSplit.TEST and anomaly != "good":
+                    anomaly_mask_path = os.path.join(maskpath, anomaly)
+                    if os.path.exists(anomaly_mask_path):
+                        anomaly_mask_files = sorted(os.listdir(anomaly_mask_path))
+                        maskpaths_per_class[classname][anomaly] = [
+                            os.path.join(anomaly_mask_path, x) for x in anomaly_mask_files
+                        ]
+                    else:
+                        # If the mask path does not exist, set to None
+                        maskpaths_per_class[classname][anomaly] = [None] * len(anomaly_files)
+                else:
+                    maskpaths_per_class[classname]["good"] = [None] * len(anomaly_files)
+
+        data_to_iterate = []
+        for classname in sorted(imgpaths_per_class.keys()):
+            for anomaly in sorted(imgpaths_per_class[classname].keys()):
+                for i, image_path in enumerate(imgpaths_per_class[classname][anomaly]):
+                    data_tuple = [classname, anomaly, image_path]
+                    if self.split == DatasetSplit.TEST and anomaly != "good":
+                        mask_path = maskpaths_per_class[classname][anomaly][i]
+                        data_tuple.append(mask_path)
+                    else:
+                        data_tuple.append(None)
+                    data_to_iterate.append(data_tuple)
+
+        return imgpaths_per_class, data_to_iterate
+
+    def __getitem__(self, idx):
+        classname, anomaly, image_path, mask_path = self.data_to_iterate[idx]
+        try:
+            image = PIL.Image.open(image_path).convert("RGB")
+        except Exception:
+            # Return a black image or handle as per your requirement.
+            # Image.new expects a single (width, height) tuple, which Resize already stores.
+            image = PIL.Image.new("RGB", tuple(self.transform_img.transforms[0].size), (0, 0, 0))
+        image = self.transform_img(image)
+
+        if self.split == DatasetSplit.TEST and mask_path is not None:
+            try:
+                mask = PIL.Image.open(mask_path).convert("L")
+                mask = self.transform_mask(mask) > 0
+            except Exception:
+                mask = torch.zeros([*self.output_shape])
+        else:
+            mask = torch.zeros([*self.output_shape])
+
+        return {
+            "image": image,  # Tensor: [3, H, W]
+            "mask": mask,  # Tensor: [1, output_size, output_size]
+            "is_anomaly": int(anomaly != "good"),
+            "image_path": image_path,
+        }
+
+    def __len__(self):
+        return len(self.data_to_iterate)
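For orientation, a minimal usage sketch of the class above (not part of the commit): the "./data" path is a placeholder, and the loader expects the same <class>/<split>/<anomaly_type>/<image>.png layout that get_image_data() walks.

# Usage sketch: wrap the dataset in a DataLoader and inspect one batch.
from torch.utils.data import DataLoader

from datasets.all_classes_dataset import AllClassesDataset, DatasetSplit

dataset = AllClassesDataset(source="./data", split=DatasetSplit.TEST)
loader = DataLoader(dataset, batch_size=1, shuffle=False)
for batch in loader:
    # image: [1, 3, 518, 518]; mask: [1, 1, 224, 224]
    print(batch["image_path"][0], batch["image"].shape, batch["mask"].shape)
    break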
datasets/mvec.py
ADDED
@@ -0,0 +1,301 @@
+from torchvision import transforms
+from .perlin import perlin_mask
+from enum import Enum
+
+import numpy as np
+import pandas as pd
+import logging
+
+LOGGER = logging.getLogger(__name__)
+import PIL
+import torch
+import os
+import glob
+
+_CLASSNAMES = [
+    "carpet",
+    "grid",
+    "leather",
+    "tile",
+    "wood",
+    "bottle",
+    "cable",
+    "capsule",
+    "hazelnut",
+    "metal_nut",
+    "pill",
+    "screw",
+    "toothbrush",
+    "transistor",
+    "zipper",
+]
+
+IMAGENET_MEAN = [0.485, 0.456, 0.406]
+IMAGENET_STD = [0.229, 0.224, 0.225]
+
+
+class DatasetSplit(Enum):
+    TRAIN = "train"
+    TEST = "test"
+
+
+class MVTecDataset(torch.utils.data.Dataset):
+    """
+    PyTorch Dataset for MVTec.
+    """
+
+    def __init__(
+        self,
+        source,
+        anomaly_source_path='/root/dataset/dtd/images',
+        dataset_name='mvtec',
+        classname='leather',
+        resize=288,
+        imagesize=288,
+        split=DatasetSplit.TRAIN,
+        rotate_degrees=0,
+        translate=0,
+        brightness_factor=0,
+        contrast_factor=0,
+        saturation_factor=0,
+        gray_p=0,
+        h_flip_p=0,
+        v_flip_p=0,
+        distribution=0,
+        mean=0.5,
+        std=0.1,
+        fg=0,
+        rand_aug=1,
+        scale=0,
+        batch_size=8,
+        **kwargs,
+    ):
+        """
+        Args:
+            source: [str]. Path to the MVTec data folder.
+            classname: [str or None]. Name of MVTec class that should be
+                       provided in this dataset. If None, the dataset
+                       iterates over all available images.
+            resize: [int]. (Square) Size the loaded image initially gets
+                    resized to.
+            imagesize: [int]. (Square) Size the resized loaded image gets
+                       (center-)cropped to.
+            split: [enum-option]. Indicates if training or test split of the
+                   data should be used. Has to be an option taken from
+                   DatasetSplit, e.g. mvtec.DatasetSplit.TRAIN. Note that
+                   mvtec.DatasetSplit.TEST will also load mask data.
+        """
+        super().__init__()
+        self.source = source
+        self.split = split
+        self.batch_size = batch_size
+        self.distribution = distribution
+        self.mean = mean
+        self.std = std
+        self.fg = fg
+        self.rand_aug = rand_aug
+        self.resize = resize if self.distribution != 1 else [resize, resize]
+        self.imgsize = imagesize
+        self.imagesize = (3, self.imgsize, self.imgsize)
+        self.classname = classname
+        self.dataset_name = dataset_name
+
+        if self.distribution != 1 and (self.classname == 'toothbrush' or self.classname == 'wood'):
+            self.resize = round(self.imgsize * 329 / 288)
+
+        xlsx_path = './datasets/excel/' + self.dataset_name + '_distribution.xlsx'
+        if self.fg == 2:  # choose by file
+            try:
+                df = pd.read_excel(xlsx_path)
+                self.class_fg = df.loc[df['Class'] == self.dataset_name + '_' + classname, 'Foreground'].values[0]
+            except Exception:
+                self.class_fg = 1
+        elif self.fg == 1:  # with foreground mask
+            self.class_fg = 1
+        else:  # without foreground mask
+            self.class_fg = 0
+
+        self.imgpaths_per_class, self.data_to_iterate = self.get_image_data()
+        # The 1*/0* factors toggle between external textures and in-distribution
+        # images as anomaly sources (list repetition keeps or drops each term).
+        self.anomaly_source_paths = sorted(1 * glob.glob(anomaly_source_path + "/*/*/*/*.png") +
+                                           0 * list(next(iter(self.imgpaths_per_class.values())).values())[0])
+        print(self.anomaly_source_paths)
+        self.transform_img = [
+            transforms.Resize(self.resize),
+            transforms.ColorJitter(brightness_factor, contrast_factor, saturation_factor),
+            transforms.RandomHorizontalFlip(h_flip_p),
+            transforms.RandomVerticalFlip(v_flip_p),
+            transforms.RandomGrayscale(gray_p),
+            transforms.RandomAffine(rotate_degrees,
+                                    translate=(translate, translate),
+                                    scale=(1.0 - scale, 1.0 + scale),
+                                    interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.CenterCrop(self.imgsize),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
+        ]
+        self.transform_img = transforms.Compose(self.transform_img)
+
+        self.transform_mask = [
+            transforms.Resize(self.resize),
+            transforms.CenterCrop(self.imgsize),
+            transforms.ToTensor(),
+        ]
+        self.transform_mask = transforms.Compose(self.transform_mask)
+
+    def rand_augmenter(self):
+        list_aug = [
+            transforms.ColorJitter(contrast=(0.8, 1.2)),
+            transforms.ColorJitter(brightness=(0.8, 1.2)),
+            transforms.ColorJitter(saturation=(0.8, 1.2), hue=(-0.2, 0.2)),
+            transforms.RandomHorizontalFlip(p=1),
+            transforms.RandomVerticalFlip(p=1),
+            transforms.RandomGrayscale(p=1),
+            transforms.RandomAutocontrast(p=1),
+            transforms.RandomEqualize(p=1),
+            transforms.RandomAffine(degrees=(-45, 45)),
+        ]
+        aug_idx = np.random.choice(np.arange(len(list_aug)), 3, replace=False)
+
+        transform_aug = [
+            transforms.Resize(self.resize),
+            list_aug[aug_idx[0]],
+            list_aug[aug_idx[1]],
+            list_aug[aug_idx[2]],
+            transforms.CenterCrop(self.imgsize),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
+        ]
+
+        transform_aug = transforms.Compose(transform_aug)
+        return transform_aug
+
+    def __getitem__(self, idx):
+        try:
+            classname, anomaly, image_path, mask_path = self.data_to_iterate[idx]
+
+            # Load the main image
+            if not os.path.exists(image_path):
+                LOGGER.warning(f"Image not found: {image_path}. Skipping index {idx}.")
+                return None
+
+            image = PIL.Image.open(image_path).convert("RGB")
+            image = self.transform_img(image)
+
+            # Initialize default tensors
+            mask_fg = mask_s = aug_image = torch.tensor([1])
+
+            if self.split == DatasetSplit.TRAIN:
+                try:
+                    aug = PIL.Image.open(np.random.choice(self.anomaly_source_paths)).convert("RGB")
+                    if self.rand_aug:
+                        transform_aug = self.rand_augmenter()
+                        aug = transform_aug(aug)
+                    else:
+                        aug = self.transform_img(aug)
+                except IndexError:
+                    LOGGER.warning(f"No anomaly source images available. Using original image as augmentation for index {idx}.")
+                    aug = image  # Use original image if no anomaly source images
+
+                # Handle foreground mask
+                if self.class_fg:
+                    fgmask_path = (
+                        image_path.split(classname)[0]
+                        + classname
+                        + "/ground_truth/"
+                        + os.path.split(image_path)[-1].replace(".png", "_mask.png")
+                    )
+                    if os.path.exists(fgmask_path):
+                        mask_fg = PIL.Image.open(fgmask_path)
+                        mask_fg = torch.ceil(self.transform_mask(mask_fg)[0])
+                    else:
+                        LOGGER.warning(f"Foreground mask not found: {fgmask_path}. Skipping mask for index {idx}.")
+                        mask_fg = torch.zeros_like(image[0])  # Default empty mask
+
+                # Generate masks and augmented images
+                mask_all = perlin_mask(image.shape, self.imgsize // 8, 0, 6, mask_fg, 1)
+                mask_s = torch.from_numpy(mask_all[0])
+                mask_l = torch.from_numpy(mask_all[1])
+
+                beta = np.random.normal(loc=self.mean, scale=self.std)
+                beta = np.clip(beta, 0.2, 0.8)
+                aug_image = image * (1 - mask_l) + (1 - beta) * aug * mask_l + beta * image * mask_l
+
+            if self.split == DatasetSplit.TEST and mask_path is not None:
+                if os.path.exists(mask_path):
+                    mask_gt = PIL.Image.open(mask_path).convert("L")
+                    mask_gt = self.transform_mask(mask_gt)
+                else:
+                    LOGGER.warning(f"Ground truth mask not found: {mask_path}. Using default empty mask for index {idx}.")
+                    mask_gt = torch.zeros([1, *image.size()[1:]])
+            else:
+                mask_gt = torch.zeros([1, *image.size()[1:]])
+
+            return {
+                "image": image,
+                "aug": aug_image,
+                "mask_s": mask_s,
+                "mask_gt": mask_gt,
+                "is_anomaly": int(anomaly != "good"),
+                "image_path": image_path,
+            }
+
+        except Exception as e:
+            LOGGER.error(f"Error processing index {idx}: {e}")
+            return None
+
+    def __len__(self):
+        return len(self.data_to_iterate)
+
+    def get_image_data(self):
+        imgpaths_per_class = {}
+        maskpaths_per_class = {}
+
+        classpath = os.path.join(self.source, self.classname, self.split.value)
+        maskpath = os.path.join(self.source, self.classname, "ground_truth")
+        anomaly_types = os.listdir(classpath)
+
+        imgpaths_per_class[self.classname] = {}
+        maskpaths_per_class[self.classname] = {}
+
+        for anomaly in anomaly_types:
+            anomaly_path = os.path.join(classpath, anomaly)
+            anomaly_files = sorted(os.listdir(anomaly_path))
+            imgpaths_per_class[self.classname][anomaly] = [os.path.join(anomaly_path, x) for x in anomaly_files]
+
+            if self.split == DatasetSplit.TEST and anomaly != "good":
+                anomaly_mask_path = os.path.join(maskpath, anomaly)
+                if os.path.exists(anomaly_mask_path):
+                    anomaly_mask_files = sorted(os.listdir(anomaly_mask_path))
+                    maskpaths_per_class[self.classname][anomaly] = [os.path.join(anomaly_mask_path, x) for x in anomaly_mask_files]
+                else:
+                    LOGGER.warning(f"Anomaly mask path does not exist: {anomaly_mask_path}. Skipping masks for {anomaly}.")
+                    maskpaths_per_class[self.classname][anomaly] = []
+            else:
+                maskpaths_per_class[self.classname]["good"] = None
+
+        data_to_iterate = []
+        for classname in sorted(imgpaths_per_class.keys()):
+            for anomaly in sorted(imgpaths_per_class[classname].keys()):
+                for i, image_path in enumerate(imgpaths_per_class[classname][anomaly]):
+                    try:
+                        if self.split == DatasetSplit.TEST and anomaly != "good":
+                            if i < len(maskpaths_per_class[classname][anomaly]):
+                                mask_path = maskpaths_per_class[classname][anomaly][i]
+                            else:
+                                LOGGER.warning(f"No corresponding mask for {image_path}. Skipping.")
+                                continue
+                        else:
+                            mask_path = None
+
+                        if os.path.exists(image_path) and (mask_path is None or os.path.exists(mask_path)):
+                            data_to_iterate.append([classname, anomaly, image_path, mask_path])
+                        else:
+                            LOGGER.warning(f"Missing required file for {image_path} or {mask_path}. Skipping.")
+                    except Exception as e:
+                        LOGGER.error(f"Error processing file {image_path}: {e}. Skipping.")
+
+        if len(data_to_iterate) == 0:
+            raise ValueError("No valid data found. Please check dataset paths and files.")
+
+        return imgpaths_per_class, data_to_iterate
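A sketch of pulling one training item (not part of the commit). The MVTec root and DTD texture folder are placeholders; the texture folder must contain .png files at the depth the glob above expects, otherwise __getitem__ falls back and may return None.

# Usage sketch: split=TRAIN exercises the Perlin-noise branch that
# produces "aug" (synthetic anomaly image) and "mask_s" (feature-level mask).
from datasets.mvec import MVTecDataset, DatasetSplit

dataset = MVTecDataset(
    source="./mvtec",                    # placeholder MVTec root
    anomaly_source_path="./dtd/images",  # placeholder DTD texture folder
    classname="leather",
    split=DatasetSplit.TRAIN,
)
item = dataset[0]
if item is not None:  # __getitem__ returns None on errors
    print(item["image"].shape, item["aug"].shape, item["mask_s"].shape)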
datasets/perlin.py
ADDED
@@ -0,0 +1,73 @@
+import imgaug.augmenters as iaa
+import numpy as np
+import torch
+import math
+
+
+def generate_thr(img_shape, min=0, max=4):
+    min_perlin_scale = min
+    max_perlin_scale = max
+    perlin_scalex = 2 ** np.random.randint(min_perlin_scale, max_perlin_scale)
+    perlin_scaley = 2 ** np.random.randint(min_perlin_scale, max_perlin_scale)
+    perlin_noise_np = rand_perlin_2d_np((img_shape[1], img_shape[2]), (perlin_scalex, perlin_scaley))
+    threshold = 0.5
+    perlin_noise_np = iaa.Sequential([iaa.Affine(rotate=(-90, 90))])(image=perlin_noise_np)
+    perlin_thr = np.where(perlin_noise_np > threshold, np.ones_like(perlin_noise_np), np.zeros_like(perlin_noise_np))
+    return perlin_thr
+
+
+def perlin_mask(img_shape, feat_size, min, max, mask_fg, flag=0):
+    mask = np.zeros((feat_size, feat_size))
+    while np.max(mask) == 0:
+        perlin_thr_1 = generate_thr(img_shape, min, max)
+        perlin_thr_2 = generate_thr(img_shape, min, max)
+        temp = torch.rand(1).numpy()[0]
+        if temp > 2 / 3:
+            perlin_thr = perlin_thr_1 + perlin_thr_2
+            perlin_thr = np.where(perlin_thr > 0, np.ones_like(perlin_thr), np.zeros_like(perlin_thr))
+        elif temp > 1 / 3:
+            perlin_thr = perlin_thr_1 * perlin_thr_2
+        else:
+            perlin_thr = perlin_thr_1
+        perlin_thr = torch.from_numpy(perlin_thr)
+        perlin_thr_fg = perlin_thr * mask_fg
+        down_ratio_y = int(img_shape[1] / feat_size)
+        down_ratio_x = int(img_shape[2] / feat_size)
+        mask_ = perlin_thr_fg
+        mask = torch.nn.functional.max_pool2d(perlin_thr_fg.unsqueeze(0).unsqueeze(0), (down_ratio_y, down_ratio_x)).float()
+        mask = mask.numpy()[0, 0]
+    mask_s = mask
+    if flag != 0:
+        mask_l = mask_.numpy()
+    if flag == 0:
+        return mask_s
+    else:
+        return mask_s, mask_l
+
+
+def lerp_np(x, y, w):
+    fin_out = (y - x) * w + x
+    return fin_out
+
+
+def rand_perlin_2d_np(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
+    delta = (res[0] / shape[0], res[1] / shape[1])
+    d = (shape[0] // res[0], shape[1] // res[1])
+    grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1
+
+    angles = 2 * math.pi * np.random.rand(res[0] + 1, res[1] + 1)
+    gradients = np.stack((np.cos(angles), np.sin(angles)), axis=-1)
+    tt = np.repeat(np.repeat(gradients, d[0], axis=0), d[1], axis=1)
+
+    tile_grads = lambda slice1, slice2: np.repeat(np.repeat(gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]], d[0], axis=0), d[1],
+                                                  axis=1)
+    dot = lambda grad, shift: (
+        np.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]),
+                 axis=-1) * grad[:shape[0], :shape[1]]).sum(axis=-1)
+
+    n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
+    n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
+    n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
+    n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
+    t = fade(grid[:shape[0], :shape[1]])
+    return math.sqrt(2) * lerp_np(lerp_np(n00, n10, t[..., 0]), lerp_np(n01, n11, t[..., 0]), t[..., 1])
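The mask generator can be exercised on its own; a minimal sketch (288×288 is chosen so every sampled Perlin resolution divides the image size, and an all-ones foreground mask means the noise is not restricted to an object):

# Standalone sketch of perlin_mask. With flag=1 it returns both the pooled
# feature-level mask and the full-resolution mask.
import torch

from datasets.perlin import perlin_mask

mask_fg = torch.ones(288, 288)
mask_s, mask_l = perlin_mask((3, 288, 288), 288 // 8, 0, 6, mask_fg, 1)
print(mask_s.shape, mask_l.shape)  # (36, 36) and (288, 288)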
datasets/rayan_dataset.py
ADDED
@@ -0,0 +1,127 @@
+# -----------------------------------------------------------------------------
+# Do Not Alter This File!
+# -----------------------------------------------------------------------------
+# The following code is part of the logic used for loading and evaluating your
+# output scores. Please DO NOT modify this section, as upon your submission,
+# the whole evaluation logic will be overwritten by the original code.
+# -----------------------------------------------------------------------------
+# If you'd like to make modifications, you can create a completely new Dataset
+# class or a child class that inherits from this one and use that with your
+# data loader.
+# -----------------------------------------------------------------------------
+
+import os
+from enum import Enum
+
+import PIL
+import torch
+from torchvision import transforms
+
+IMAGENET_MEAN = [0.485, 0.456, 0.406]
+IMAGENET_STD = [0.229, 0.224, 0.225]
+
+
+class DatasetSplit(Enum):
+    TRAIN = "train"
+    VAL = "val"
+    TEST = "test"
+
+
+class RayanDataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        source,
+        classname,
+        input_size=518,
+        output_size=224,
+        split=DatasetSplit.TEST,
+        external_transform=None,
+        **kwargs,
+    ):
+        super().__init__()
+        self.source = source
+        self.split = split
+        self.classnames_to_use = [classname]
+        self.imgpaths_per_class, self.data_to_iterate = self.get_image_data()
+
+        if external_transform is None:
+            self.transform_img = [
+                transforms.Resize((input_size, input_size)),
+                transforms.CenterCrop(input_size),
+                transforms.ToTensor(),
+                transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
+            ]
+            self.transform_img = transforms.Compose(self.transform_img)
+        else:
+            self.transform_img = external_transform
+
+        # Output size of the mask has to be of shape: 1×224×224
+        self.transform_mask = [
+            transforms.Resize((output_size, output_size)),
+            transforms.CenterCrop(output_size),
+            transforms.ToTensor(),
+        ]
+        self.transform_mask = transforms.Compose(self.transform_mask)
+        self.output_shape = (1, output_size, output_size)
+
+    def __getitem__(self, idx):
+        classname, anomaly, image_path, mask_path = self.data_to_iterate[idx]
+        image = PIL.Image.open(image_path).convert("RGB")
+        image = self.transform_img(image)
+
+        if self.split == DatasetSplit.TEST and mask_path is not None:
+            mask = PIL.Image.open(mask_path).convert("L")
+            mask = self.transform_mask(mask) > 0
+        else:
+            mask = torch.zeros([*self.output_shape])
+
+        return {
+            "image": image,
+            "mask": mask,
+            "is_anomaly": int(anomaly != "good"),
+            "image_path": image_path,
+        }
+
+    def __len__(self):
+        return len(self.data_to_iterate)
+
+    def get_image_data(self):
+        imgpaths_per_class = {}
+        maskpaths_per_class = {}
+
+        for classname in self.classnames_to_use:
+            classpath = os.path.join(self.source, classname, self.split.value)
+            maskpath = os.path.join(self.source, classname, "ground_truth")
+            anomaly_types = os.listdir(classpath)
+
+            imgpaths_per_class[classname] = {}
+            maskpaths_per_class[classname] = {}
+
+            for anomaly in anomaly_types:
+                anomaly_path = os.path.join(classpath, anomaly)
+                anomaly_files = sorted(os.listdir(anomaly_path))
+                imgpaths_per_class[classname][anomaly] = [
+                    os.path.join(anomaly_path, x) for x in anomaly_files
+                ]
+
+                if self.split == DatasetSplit.TEST and anomaly != "good":
+                    anomaly_mask_path = os.path.join(maskpath, anomaly)
+                    anomaly_mask_files = sorted(os.listdir(anomaly_mask_path))
+                    maskpaths_per_class[classname][anomaly] = [
+                        os.path.join(anomaly_mask_path, x) for x in anomaly_mask_files
+                    ]
+                else:
+                    maskpaths_per_class[classname]["good"] = None
+
+        data_to_iterate = []
+        for classname in sorted(imgpaths_per_class.keys()):
+            for anomaly in sorted(imgpaths_per_class[classname].keys()):
+                for i, image_path in enumerate(imgpaths_per_class[classname][anomaly]):
+                    data_tuple = [classname, anomaly, image_path]
+                    if self.split == DatasetSplit.TEST and anomaly != "good":
+                        data_tuple.append(maskpaths_per_class[classname][anomaly][i])
+                    else:
+                        data_tuple.append(None)
+                    data_to_iterate.append(data_tuple)
+
+        return imgpaths_per_class, data_to_iterate
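A sketch of instantiating the evaluation dataset for a single category (the "./data" path is a placeholder, and "photovoltaic_module" must exist under it):

# Usage sketch: one category of the evaluation dataset. The 1x224x224 mask
# contract is what the evaluator's shape assertions later check.
from datasets.rayan_dataset import RayanDataset, DatasetSplit

dataset = RayanDataset(source="./data", classname="photovoltaic_module",
                       split=DatasetSplit.TEST)
sample = dataset[0]
print(sample["image"].shape, sample["mask"].shape)  # [3, 518, 518], [1, 224, 224]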
docker-compose.yml
ADDED
@@ -0,0 +1,21 @@
+# -----------------------------------------------------------------------------
+# A sample Docker Compose file to help you replicate our test environment
+# -----------------------------------------------------------------------------
+
+services:
+  zsad-service:
+    image: zsad-image:1
+    build:
+      context: .
+    container_name: zsad-container
+    volumes:
+      - ./shared_folder:/app/output
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              count: all
+              capabilities: [gpu]
+
+    command: [ "python3", "runner.py" ]
evaluation/base_eval.py
ADDED
@@ -0,0 +1,293 @@
+# -----------------------------------------------------------------------------
+# Do Not Alter This File!
+# -----------------------------------------------------------------------------
+# The following code is part of the logic used for loading and evaluating your
+# output scores. Please DO NOT modify this section, as upon your submission,
+# the whole evaluation logic will be overwritten by the original code.
+# -----------------------------------------------------------------------------
+
+import warnings
+import os
+from pathlib import Path
+import csv
+import json
+import torch
+
+import datasets.rayan_dataset as rayan_dataset
+from evaluation.utils.metrics import compute_metrics
+
+warnings.filterwarnings("ignore")
+
+
+class BaseEval:
+    def __init__(self, cfg):
+        self.cfg = cfg
+        self.device = torch.device(
+            "cuda:{}".format(cfg["device"]) if torch.cuda.is_available() else "cpu"
+        )
+
+        self.path = cfg["datasets"]["data_path"]
+        self.dataset = cfg["datasets"]["dataset_name"]
+        self.save_csv = cfg["testing"]["save_csv"]
+        self.save_json = cfg["testing"]["save_json"]
+        self.categories = cfg["datasets"]["class_name"]
+        if isinstance(self.categories, str):
+            if self.categories.lower() == "all":
+                if self.dataset == "rayan_dataset":
+                    self.categories = self.get_available_class_names(self.path)
+            else:
+                self.categories = [self.categories]
+        self.output_dir = cfg["testing"]["output_dir"]
+        os.makedirs(self.output_dir, exist_ok=True)
+        self.scores_dir = cfg["testing"]["output_scores_dir"]
+        self.class_name_mapping_dir = cfg["testing"]["class_name_mapping_dir"]
+
+        self.leaderboard_metric_weights = {
+            "image_auroc": 1.2,
+            "image_ap": 1.1,
+            "image_f1": 1.1,
+            "pixel_auroc": 1.0,
+            "pixel_aupro": 1.4,
+            "pixel_ap": 1.3,
+            "pixel_f1": 1.3,
+        }
+
+    def get_available_class_names(self, root_data_path):
+        all_items = os.listdir(root_data_path)
+        folder_names = [
+            item
+            for item in all_items
+            if os.path.isdir(os.path.join(root_data_path, item))
+        ]
+
+        return folder_names
+
+    def load_datasets(self, category):
+        dataset_classes = {
+            "rayan_dataset": rayan_dataset.RayanDataset,
+        }
+
+        dataset_splits = {
+            "rayan_dataset": rayan_dataset.DatasetSplit.TEST,
+        }
+
+        test_dataset = dataset_classes[self.dataset](
+            source=self.path,
+            split=dataset_splits[self.dataset],
+            classname=category,
+        )
+        return test_dataset
+
+    def get_category_metrics(self, category):
+        print(f"Loading scores of '{category}'")
+        gt_sp, pr_sp, gt_px, pr_px, _ = self.load_category_scores(category)
+
+        print(f"Computing metrics for '{category}'")
+        image_metric, pixel_metric = compute_metrics(gt_sp, pr_sp, gt_px, pr_px)
+
+        return image_metric, pixel_metric
+
+    def load_category_scores(self, category):
+        raise NotImplementedError()
+
+    def get_scores_path_for_image(self, image_path):
+        """example image_path: './data/photovoltaic_module/test/good/037.png'"""
+        path = Path(image_path)
+
+        category, split, anomaly_type = path.parts[-4:-1]
+        image_name = path.stem
+
+        return os.path.join(
+            self.scores_dir, category, split, anomaly_type, f"{image_name}_scores.json"
+        )
+
+    def calc_leaderboard_score(self, **metrics):
+        weighted_sum = 0
+        total_weight = 0
+        for key, weight in self.leaderboard_metric_weights.items():
+            metric = metrics.get(key)
+            weighted_sum += metric * weight
+            total_weight += weight
+
+        if total_weight == 0:
+            return 0
+
+        return weighted_sum / total_weight
+
+    def main(self):
+        image_auroc_list = []
+        image_f1_list = []
+        image_ap_list = []
+        pixel_auroc_list = []
+        pixel_f1_list = []
+        pixel_ap_list = []
+        pixel_aupro_list = []
+        leaderboard_score_list = []
+        for category in self.categories:
+            image_metric, pixel_metric = self.get_category_metrics(
+                category=category,
+            )
+            image_auroc, image_f1, image_ap = image_metric
+            pixel_auroc, pixel_f1, pixel_ap, pixel_aupro = pixel_metric
+            leaderboard_score = self.calc_leaderboard_score(
+                image_auroc=image_auroc,
+                image_f1=image_f1,
+                image_ap=image_ap,
+                pixel_auroc=pixel_auroc,
+                pixel_aupro=pixel_aupro,
+                pixel_f1=pixel_f1,
+                pixel_ap=pixel_ap,
+            )
+
+            image_auroc_list.append(image_auroc)
+            image_f1_list.append(image_f1)
+            image_ap_list.append(image_ap)
+            pixel_auroc_list.append(pixel_auroc)
+            pixel_f1_list.append(pixel_f1)
+            pixel_ap_list.append(pixel_ap)
+            pixel_aupro_list.append(pixel_aupro)
+            leaderboard_score_list.append(leaderboard_score)
+
+            print(category)
+            print(
+                "[image level] auroc:{}, f1:{}, ap:{}".format(
+                    image_auroc * 100,
+                    image_f1 * 100,
+                    image_ap * 100,
+                )
+            )
+            print(
+                "[pixel level] auroc:{}, f1:{}, ap:{}, aupro:{}".format(
+                    pixel_auroc * 100,
+                    pixel_f1 * 100,
+                    pixel_ap * 100,
+                    pixel_aupro * 100,
+                )
+            )
+            print(
+                "leaderboard score:{}".format(
+                    leaderboard_score * 100,
+                )
+            )
+
+        image_auroc_mean = sum(image_auroc_list) / len(image_auroc_list)
+        image_f1_mean = sum(image_f1_list) / len(image_f1_list)
+        image_ap_mean = sum(image_ap_list) / len(image_ap_list)
+        pixel_auroc_mean = sum(pixel_auroc_list) / len(pixel_auroc_list)
+        pixel_f1_mean = sum(pixel_f1_list) / len(pixel_f1_list)
+        pixel_ap_mean = sum(pixel_ap_list) / len(pixel_ap_list)
+        pixel_aupro_mean = sum(pixel_aupro_list) / len(pixel_aupro_list)
+        leaderboard_score_mean = sum(leaderboard_score_list) / len(
+            leaderboard_score_list
+        )
+
+        print("mean")
+        print(
+            "[image level] auroc:{}, f1:{}, ap:{}".format(
+                image_auroc_mean * 100, image_f1_mean * 100, image_ap_mean * 100
+            )
+        )
+        print(
+            "[pixel level] auroc:{}, f1:{}, ap:{}, aupro:{}".format(
+                pixel_auroc_mean * 100,
+                pixel_f1_mean * 100,
+                pixel_ap_mean * 100,
+                pixel_aupro_mean * 100,
+            )
+        )
+        print(
+            "leaderboard score:{}".format(
+                leaderboard_score_mean * 100,
+            )
+        )
+
+        # Save the final results as a csv file
+        if self.save_csv:
+            with open(self.class_name_mapping_dir, "r") as f:
+                class_name_mapping_dict = json.load(f)
+            csv_data = [
+                [
+                    "Category",
+                    "pixel_auroc",
+                    "pixel_f1",
+                    "pixel_ap",
+                    "pixel_aupro",
+                    "image_auroc",
+                    "image_f1",
+                    "image_ap",
+                    "leaderboard_score",
+                ]
+            ]
+            for i, category in enumerate(self.categories):
+                csv_data.append(
+                    [
+                        class_name_mapping_dict[category],
+                        pixel_auroc_list[i] * 100,
+                        pixel_f1_list[i] * 100,
+                        pixel_ap_list[i] * 100,
+                        pixel_aupro_list[i] * 100,
+                        image_auroc_list[i] * 100,
+                        image_f1_list[i] * 100,
+                        image_ap_list[i] * 100,
+                        leaderboard_score_list[i] * 100,
+                    ]
+                )
+            csv_data.append(
+                [
+                    "mean",
+                    pixel_auroc_mean * 100,
+                    pixel_f1_mean * 100,
+                    pixel_ap_mean * 100,
+                    pixel_aupro_mean * 100,
+                    image_auroc_mean * 100,
+                    image_f1_mean * 100,
+                    image_ap_mean * 100,
+                    leaderboard_score_mean * 100,
+                ]
+            )
+
+            csv_file_path = os.path.join(self.output_dir, "results.csv")
+            with open(csv_file_path, mode="w", newline="") as file:
+                writer = csv.writer(file)
+                writer.writerows(csv_data)
+
+        # Save the final results as a json file
+        if self.save_json:
+            json_data = []
+            with open(self.class_name_mapping_dir, "r") as f:
+                class_name_mapping_dict = json.load(f)
+            for i, category in enumerate(self.categories):
+                json_data.append(
+                    {
+                        "Category": class_name_mapping_dict[category],
+                        "pixel_auroc": pixel_auroc_list[i] * 100,
+                        "pixel_f1": pixel_f1_list[i] * 100,
+                        "pixel_ap": pixel_ap_list[i] * 100,
+                        "pixel_aupro": pixel_aupro_list[i] * 100,
+                        "image_auroc": image_auroc_list[i] * 100,
+                        "image_f1": image_f1_list[i] * 100,
+                        "image_ap": image_ap_list[i] * 100,
+                        "leaderboard_score": leaderboard_score_list[i] * 100,
+                    }
+                )
+            json_data.append(
+                {
+                    "Category": "mean",
+                    "pixel_auroc": pixel_auroc_mean * 100,
+                    "pixel_f1": pixel_f1_mean * 100,
+                    "pixel_ap": pixel_ap_mean * 100,
+                    "pixel_aupro": pixel_aupro_mean * 100,
+                    "image_auroc": image_auroc_mean * 100,
+                    "image_f1": image_f1_mean * 100,
+                    "image_ap": image_ap_mean * 100,
+                    "leaderboard_score": leaderboard_score_mean * 100,
+                }
+            )
+
+            json_file_path = os.path.join(self.output_dir, "results.json")
+            with open(json_file_path, mode="w") as file:
+                final_json = {
+                    "result": leaderboard_score_mean * 100,
+                    "metadata": json_data,
+                }
+                json.dump(final_json, file, indent=4)
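The leaderboard score is a weighted mean of the seven metrics; a standalone arithmetic check with hypothetical metric values and the same weights as above:

# Arithmetic check of the weighted mean in calc_leaderboard_score.
weights = {"image_auroc": 1.2, "image_ap": 1.1, "image_f1": 1.1,
           "pixel_auroc": 1.0, "pixel_aupro": 1.4, "pixel_ap": 1.3, "pixel_f1": 1.3}
metrics = dict.fromkeys(weights, 0.9)  # every metric at 0.9 (hypothetical)
score = sum(metrics[k] * w for k, w in weights.items()) / sum(weights.values())
print(round(score, 6))  # 0.9: a uniform vector of metrics reproduces itself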
evaluation/class_name_mapping.json
ADDED
@@ -0,0 +1,5 @@
+{
+    "pill": "industrial_01",
+    "photovoltaic_module": "industrial_02",
+    "capsules": "industrial_03"
+}
evaluation/eval_main.py
ADDED
@@ -0,0 +1,78 @@
+# -----------------------------------------------------------------------------
+# Do Not Alter This File!
+# -----------------------------------------------------------------------------
+# The following code is part of the logic used for loading and evaluating your
+# output scores. Please DO NOT modify this section, as upon your submission,
+# the whole evaluation logic will be overwritten by the original code.
+# -----------------------------------------------------------------------------
+
+import warnings
+import argparse
+import os
+import sys
+
+sys.path.append(os.getcwd())
+from evaluation.json_score import JsonScoreEvaluator
+
+warnings.filterwarnings("ignore")
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description="Rayan ZSAD Evaluation Code")
+    parser.add_argument("--data_path", type=str, default=None, help="dataset path")
+    parser.add_argument("--dataset_name", type=str, default=None, help="dataset name")
+    parser.add_argument("--class_name", type=str, default=None, help="category")
+    parser.add_argument("--device", type=int, default=None, help="gpu id")
+    parser.add_argument(
+        "--output_dir", type=str, default=None, help="save results path"
+    )
+    parser.add_argument(
+        "--output_scores_dir", type=str, default=None, help="save scores path"
+    )
+    parser.add_argument("--save_csv", type=str, default=None, help="save csv")
+    parser.add_argument("--save_json", type=str, default=None, help="save json")
+
+    parser.add_argument(
+        "--class_name_mapping_dir",
+        type=str,
+        default=None,
+        help="mapping from actual class names to class numbers",
+    )
+    args = parser.parse_args()
+    return args
+
+
+def load_args(cfg, args):
+    cfg["datasets"]["data_path"] = args.data_path
+    assert os.path.exists(
+        cfg["datasets"]["data_path"]
+    ), f"The dataset path {cfg['datasets']['data_path']} does not exist."
+    cfg["datasets"]["dataset_name"] = args.dataset_name
+    cfg["datasets"]["class_name"] = args.class_name
+    cfg["device"] = args.device
+    if isinstance(cfg["device"], int):
+        cfg["device"] = str(cfg["device"])
+    cfg["testing"]["output_dir"] = args.output_dir
+    cfg["testing"]["output_scores_dir"] = args.output_scores_dir
+    os.makedirs(cfg["testing"]["output_scores_dir"], exist_ok=True)
+
+    cfg["testing"]["class_name_mapping_dir"] = args.class_name_mapping_dir
+    if args.save_csv.lower() == "true":
+        cfg["testing"]["save_csv"] = True
+    else:
+        cfg["testing"]["save_csv"] = False
+
+    if args.save_json.lower() == "true":
+        cfg["testing"]["save_json"] = True
+    else:
+        cfg["testing"]["save_json"] = False
+
+    return cfg
+
+
+if __name__ == "__main__":
+    args = get_args()
+    cfg = load_args(cfg={"datasets": {}, "testing": {}, "models": {}}, args=args)
+    print(cfg)
+    model = JsonScoreEvaluator(cfg=cfg)
+    model.main()
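The same evaluator can be driven without the CLI by assembling the cfg dict that load_args would build; a sketch with placeholder paths (score files must already exist under the scores directory in the layout json_score.py expects):

# Sketch of invoking the evaluator programmatically; all paths are placeholders.
from evaluation.json_score import JsonScoreEvaluator

cfg = {
    "device": 0,
    "datasets": {
        "data_path": "./data",
        "dataset_name": "rayan_dataset",
        "class_name": "all",
    },
    "testing": {
        "output_dir": "./output",
        "output_scores_dir": "./output/scores",
        "class_name_mapping_dir": "./evaluation/class_name_mapping.json",
        "save_csv": True,
        "save_json": True,
    },
}
JsonScoreEvaluator(cfg=cfg).main()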
evaluation/json_score.py
ADDED
@@ -0,0 +1,98 @@
+# -----------------------------------------------------------------------------
+# Do Not Alter This File!
+# -----------------------------------------------------------------------------
+# The following code is part of the logic used for loading and evaluating your
+# output scores. Please DO NOT modify this section, as upon your submission,
+# the whole evaluation logic will be overwritten by the original code.
+# -----------------------------------------------------------------------------
+
+import warnings
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from evaluation.base_eval import BaseEval
+from evaluation.utils.json_helpers import json_to_dict
+
+warnings.filterwarnings("ignore")
+
+
+class JsonScoreEvaluator(BaseEval):
+    """
+    Evaluates anomaly detection performance based on pre-computed scores stored in JSON files.
+
+    This class extends the BaseEval class and specializes in reading scores from JSON files,
+    computing evaluation metrics, and optionally saving results to CSV or JSON format.
+
+    Notes:
+        - Score files are expected to follow the exact dataset structure.
+          `{category}/{split}/{anomaly_type}/{image_name}_scores.json`
+          e.g., `photovoltaic_module/test/good/037_scores.json`
+        - Score files are expected to be at `self.scores_dir`.
+
+    Example usage:
+        >>> evaluator = JsonScoreEvaluator(cfg)
+        >>> results = evaluator.main()
+    """
+
+    def __init__(self, cfg):
+        super().__init__(cfg)
+
+    def get_scores_for_image(self, image_path):
+        image_scores_path = self.get_scores_path_for_image(image_path)
+        image_scores = json_to_dict(image_scores_path)
+
+        return image_scores
+
+    def load_category_scores(self, category):
+        cls_scores_list = []  # image level prediction
+        anomaly_maps = []  # pixel level prediction
+        gt_list = []  # image level ground truth
+        img_masks = []  # pixel level ground truth
+
+        image_path_list = []
+        test_dataset = self.load_datasets(category)
+        test_dataloader = torch.utils.data.DataLoader(
+            test_dataset,
+            batch_size=1,
+            shuffle=False,
+            num_workers=0,
+            pin_memory=True,
+        )
+
+        for image_info in tqdm(test_dataloader):
+            if not isinstance(image_info, dict):
+                raise ValueError("Encountered non-dict image in dataloader")
+
+            del image_info["image"]
+
+            image_path = image_info["image_path"][0]
+            image_path_list.extend(image_path)
+
+            img_masks.append(image_info["mask"])
+            gt_list.extend(list(image_info["is_anomaly"].numpy()))
+
+            image_scores = self.get_scores_for_image(image_path)
+            cls_scores = image_scores["img_level_score"]
+            anomaly_maps_iter = image_scores["pix_level_score"]
+
+            cls_scores_list.append(cls_scores)
+            anomaly_maps.append(anomaly_maps_iter)
+
+        pr_sp = np.array(cls_scores_list)
+        gt_sp = np.array(gt_list)
+        pr_px = np.array(anomaly_maps)
+        gt_px = torch.cat(img_masks, dim=0).numpy().astype(np.int32)
+        print(pr_px.shape)
+        assert pr_px.shape[1:] == (
+            1,
+            224,
+            224,
+        ), "Predicted output scores do not meet the expected shape!"
+        assert gt_px.shape[1:] == (
+            1,
+            224,
+            224,
+        ), "Loaded ground truth maps do not meet the expected shape!"
+
+        return gt_sp, pr_sp, gt_px, pr_px, image_path_list
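Each per-image score file must therefore expose an img_level_score scalar and a pix_level_score map of shape 1×224×224. A sketch that writes one conforming file with the provided helper (the scores here are random placeholders, and the filename is illustrative):

# Sketch of producing one conforming score file.
import numpy as np

from evaluation.utils.json_helpers import dict_to_json

scores = {
    "img_level_score": float(np.random.rand()),      # scalar image-level score
    "pix_level_score": np.random.rand(1, 224, 224),  # per-pixel anomaly map
}
dict_to_json(scores, "037_scores.json")  # placeholder filename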
evaluation/utils/json_helpers.py
ADDED
@@ -0,0 +1,46 @@
+# -----------------------------------------------------------------------------
+# Do Not Alter This File!
+# -----------------------------------------------------------------------------
+# The following code is part of the logic used for loading and evaluating your
+# output scores. Please DO NOT modify this section, as upon your submission,
+# the whole evaluation logic will be overwritten by the original code.
+# -----------------------------------------------------------------------------
+
+import json
+import numpy as np
+
+
+class NumpyEncoder(json.JSONEncoder):
+    """Special json encoder for numpy types"""
+
+    def default(self, obj):
+        if isinstance(obj, np.integer):
+            return int(obj)
+        elif isinstance(obj, np.floating):
+            return float(obj)
+        elif isinstance(obj, np.ndarray):
+            return {
+                "__ndarray__": obj.tolist(),
+                "dtype": str(obj.dtype),
+                "shape": obj.shape,
+            }
+        else:
+            return super(NumpyEncoder, self).default(obj)
+
+
+def dict_to_json(dct, filename):
+    """Save a dictionary to a JSON file"""
+    with open(filename, "w") as f:
+        json.dump(dct, f, cls=NumpyEncoder)
+
+
+def json_to_dict(filename):
+    """Load a JSON file and convert it back to a dictionary of NumPy arrays"""
+    with open(filename, "r") as f:
+        dct = json.load(f)
+
+        for k, v in dct.items():
+            if isinstance(v, dict) and "__ndarray__" in v:
+                dct[k] = np.array(v["__ndarray__"], dtype=v["dtype"]).reshape(v["shape"])
+
+        return dct
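Because the encoder stores each array alongside its dtype and shape, a write/read round trip is lossless; a quick check:

# Round-trip check for the ndarray encoding above ("demo.json" is a scratch file).
import numpy as np

from evaluation.utils.json_helpers import dict_to_json, json_to_dict

arr = np.arange(6, dtype=np.float32).reshape(2, 3)
dict_to_json({"pix_level_score": arr}, "demo.json")
restored = json_to_dict("demo.json")
assert np.array_equal(restored["pix_level_score"], arr)
assert restored["pix_level_score"].dtype == np.float32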
evaluation/utils/metrics.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -----------------------------------------------------------------------------
|
2 |
+
# Do Not Alter This File!
|
3 |
+
# -----------------------------------------------------------------------------
|
4 |
+
# The following code is part of the logic used for loading and evaluating your
|
5 |
+
# output scores. Please DO NOT modify this section, as upon your submission,
# the whole evaluation logic will be overwritten by the original code.
# -----------------------------------------------------------------------------

import numpy as np
from sklearn.metrics import (
    auc,
    roc_auc_score,
    average_precision_score,
    precision_recall_curve,
)
from skimage import measure
import warnings


# ref: https://github.com/gudovskiy/cflow-ad/blob/master/train.py
def cal_pro_score(masks, amaps, max_step=200, expect_fpr=0.3):
    binary_amaps = np.zeros_like(amaps, dtype=bool)
    min_th, max_th = amaps.min(), amaps.max()
    delta = (max_th - min_th) / max_step
    pros, fprs, ths = [], [], []
    for th in np.arange(min_th, max_th, delta):
        binary_amaps[amaps <= th], binary_amaps[amaps > th] = 0, 1
        pro = []
        for binary_amap, mask in zip(binary_amaps, masks):
            for region in measure.regionprops(measure.label(mask)):
                tp_pixels = binary_amap[region.coords[:, 0], region.coords[:, 1]].sum()
                pro.append(tp_pixels / region.area)
        inverse_masks = 1 - masks
        fp_pixels = np.logical_and(inverse_masks, binary_amaps).sum()
        fpr = fp_pixels / inverse_masks.sum()
        pros.append(np.array(pro).mean())
        fprs.append(fpr)
        ths.append(th)
    pros, fprs, ths = np.array(pros), np.array(fprs), np.array(ths)
    idxes = fprs < expect_fpr
    fprs = fprs[idxes]
    # Normalize the retained FPRs to [0, 1] before integrating
    fprs = (fprs - fprs.min()) / (fprs.max() - fprs.min())
    pro_auc = auc(fprs, pros[idxes])
    return pro_auc


def compute_metrics(gt_sp=None, pr_sp=None, gt_px=None, pr_px=None):
    # classification
    if (
        gt_sp is None
        or pr_sp is None
        or gt_sp.sum() == 0
        or gt_sp.sum() == gt_sp.shape[0]
    ):
        auroc_sp, f1_sp, ap_sp = 0, 0, 0
    else:
        auroc_sp = roc_auc_score(gt_sp, pr_sp)
        ap_sp = average_precision_score(gt_sp, pr_sp)
        precisions, recalls, thresholds = precision_recall_curve(gt_sp, pr_sp)
        f1_scores = (2 * precisions * recalls) / (precisions + recalls)
        f1_sp = np.max(f1_scores[np.isfinite(f1_scores)])

    # segmentation
    if gt_px is None or pr_px is None or gt_px.sum() == 0:
        auroc_px, f1_px, ap_px, aupro = 0, 0, 0, 0
    else:
        auroc_px = roc_auc_score(gt_px.ravel(), pr_px.ravel())
        ap_px = average_precision_score(gt_px.ravel(), pr_px.ravel())
        precisions, recalls, thresholds = precision_recall_curve(
            gt_px.ravel(), pr_px.ravel()
        )
        f1_scores = (2 * precisions * recalls) / (precisions + recalls)
        f1_px = np.max(f1_scores[np.isfinite(f1_scores)])
        aupro = cal_pro_score(gt_px.squeeze(), pr_px.squeeze())

    image_metric = [auroc_sp, f1_sp, ap_sp]
    pixel_metric = [auroc_px, f1_px, ap_px, aupro]

    return image_metric, pixel_metric


def compute_auroc(labels, scores):
    """
    Computes the Area Under the Receiver Operating Characteristic Curve (AUROC).

    Args:
        labels (list or np.ndarray): True binary labels (0 for normal, 1 for anomaly).
        scores (list or np.ndarray): Predicted scores or probabilities for the positive class.

    Returns:
        float: AUROC score. Returns None if AUROC is undefined.
    """
    # Convert inputs to numpy arrays
    labels = np.array(labels)
    scores = np.array(scores)

    # Ensure that labels are binary. Check for a subset of {0, 1} so that the
    # single-class case below remains reachable instead of raising here.
    unique_labels = np.unique(labels)
    if not set(unique_labels).issubset({0, 1}):
        raise ValueError(f"Labels must be binary (0 and 1). Found labels: {unique_labels}")

    # Check if both classes are present
    if len(unique_labels) < 2:
        warnings.warn("Only one class present in labels. AUROC is undefined.")
        return None

    try:
        auroc = roc_auc_score(labels, scores)
        return auroc
    except ValueError as e:
        warnings.warn(f"Error computing AUROC: {e}")
        return None
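A minimal smoke test for these helpers on synthetic arrays; every value below is made up for illustration, and it assumes the repository root is on PYTHONPATH so that evaluation.utils.metrics is importable:

# Hedged smoke test for compute_metrics / compute_auroc on synthetic data.
import numpy as np
from evaluation.utils.metrics import compute_metrics, compute_auroc

rng = np.random.default_rng(0)
gt_sp = np.array([0, 0, 1, 1])                         # image-level labels
pr_sp = np.array([0.1, 0.4, 0.35, 0.8])                # image-level scores
gt_px = (rng.random((4, 32, 32)) > 0.95).astype(int)   # sparse pixel masks
pr_px = rng.random((4, 32, 32))                        # pixel-level scores

image_metric, pixel_metric = compute_metrics(gt_sp, pr_sp, gt_px, pr_px)
print("image [auroc, f1, ap]:", image_metric)
print("pixel [auroc, f1, ap, aupro]:", pixel_metric)
print("auroc:", compute_auroc(gt_sp, pr_sp))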
main.py
ADDED
@@ -0,0 +1,308 @@
# main.py

import os
import torch
from torch.utils.data import DataLoader
from datasets.all_classes_dataset import AllClassesDataset, DatasetSplit
from models.anomaly_detector import AnomalyDetector
from utils.dump_scores import DumpScores
import logging
import json
from sklearn.metrics import average_precision_score, roc_auc_score, f1_score
import numpy as np
import torch.nn.functional as F
import random


def set_seed(seed: int):
    """
    Set the seed for reproducibility across various libraries.

    Args:
        seed (int): The seed value to be set.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # For multi-GPU setups

    # Ensure deterministic behavior in PyTorch
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # For DataLoader workers
    os.environ['PYTHONHASHSEED'] = str(seed)


def worker_init_fn(worker_id):
    """
    Initialize the seed for each DataLoader worker to ensure reproducibility.

    Args:
        worker_id (int): The worker ID.
    """
    seed = torch.initial_seed()
    np.random.seed(seed % 2**32)
    random.seed(seed % 2**32)


def compute_aupro(y_true_pixel, y_scores_pixel, num_thresholds=50):
    """
    Compute Area Under the Per-Region Overlap Curve (AUPRO).

    Args:
        y_true_pixel (np.ndarray): Ground truth binary masks, shape [N, H, W]
        y_scores_pixel (np.ndarray): Predicted anomaly scores, shape [N, H, W]
        num_thresholds (int): Number of thresholds to evaluate.

    Returns:
        float: AUPRO score.
    """
    # Define thresholds
    thresholds = np.linspace(0, 1, num_thresholds)

    # Initialize list to store overlaps
    overlaps = []

    for thresh in thresholds:
        # Binarize predictions
        y_pred = (y_scores_pixel >= thresh).astype(int)

        # Compute Intersection over Union (IoU) for each sample
        ious = []
        for gt, pred in zip(y_true_pixel, y_pred):
            intersection = np.logical_and(gt, pred).sum()
            union = np.logical_or(gt, pred).sum()
            if union == 0:
                iou = 1.0  # If both gt and pred are all zeros
            else:
                iou = intersection / union
            ious.append(iou)

        # Average IoU over all samples
        avg_iou = np.mean(ious)
        overlaps.append(avg_iou)

    # Compute the area under the overlap curve, normalized by the threshold range
    aupro = np.trapz(overlaps, thresholds) / np.trapz([1] * len(thresholds), thresholds)
    return aupro


def compute_metrics(y_true_image, y_scores_image, y_true_pixel, y_scores_pixel):
    """
    Compute the required metrics based on true labels and predicted scores.

    Args:
        y_true_image (np.ndarray): Ground truth image labels, shape [N]
        y_scores_image (np.ndarray): Predicted image scores, shape [N]
        y_true_pixel (np.ndarray): Ground truth pixel masks, shape [N, H, W]
        y_scores_pixel (np.ndarray): Predicted pixel anomaly scores, shape [N, H, W]

    Returns:
        dict: Dictionary containing computed metrics.
    """
    # Check image-level consistency
    if len(y_true_image) != len(y_scores_image):
        raise ValueError(f"Image-level y_true and y_scores have different lengths: {len(y_true_image)} vs {len(y_scores_image)}")

    # Check pixel-level consistency
    if y_true_pixel.shape != y_scores_pixel.shape:
        raise ValueError(f"Pixel-level y_true and y_scores have different shapes: {y_true_pixel.shape} vs {y_scores_pixel.shape}")

    # Image-level Metrics
    image_ap = average_precision_score(y_true_image, y_scores_image)
    image_auroc = roc_auc_score(y_true_image, y_scores_image)
    y_pred_image = (y_scores_image >= 0.5).astype(int)
    image_f1 = f1_score(y_true_image, y_pred_image)

    # Pixel-level Metrics
    pixel_ap = average_precision_score(y_true_pixel.flatten(), y_scores_pixel.flatten())
    pixel_auroc = roc_auc_score(y_true_pixel.flatten(), y_scores_pixel.flatten())
    pixel_aupro = compute_aupro(y_true_pixel, y_scores_pixel)
    y_pred_pixel = (y_scores_pixel >= 0.5).astype(int)
    pixel_f1 = f1_score(y_true_pixel.flatten(), y_pred_pixel.flatten())

    # Compute leaderboard_score as a weighted average (example weights)
    # Adjust weights as per your specific requirements
    leaderboard_score = (
        0.25 * image_auroc +
        0.25 * image_f1 +
        0.25 * pixel_auroc +
        0.25 * pixel_f1
    )

    metrics = {
        "image_metrics": {
            "image_ap": round(float(image_ap), 4),
            "image_auroc": round(float(image_auroc), 4),
            "image_f1": round(float(image_f1), 4)
        },
        "pixel_metrics": {
            "pixel_ap": round(float(pixel_ap), 4),
            "pixel_aupro": round(float(pixel_aupro), 4),
            "pixel_auroc": round(float(pixel_auroc), 4),
            "pixel_f1": round(float(pixel_f1), 4)
        },
        "overall_metric": {
            "leaderboard_score": round(float(leaderboard_score), 4)
        }
    }

    return metrics


def get_class_name(image_path, source_dir):
    """
    Extract the class name from the image path.

    Args:
        image_path (str): Path to the image file.
        source_dir (str): Root source directory.

    Returns:
        str: Class name.
    """
    # Example image_path: "./data/pill/test/broken/image1.png"
    rel_path = os.path.relpath(image_path, source_dir)  # "pill/test/broken/image1.png"
    parts = rel_path.split(os.sep)
    if len(parts) < 2:
        raise ValueError(f"Unexpected image path format: {image_path}")
    class_name = parts[0]  # "pill"
    return class_name


def main():
    SEED = 41  # You can choose any integer value
    set_seed(SEED)
    # Configure logging
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

    # Configuration
    source_dir = "./data"
    output_scores_dir = "./output_scores"
    split = DatasetSplit.TEST  # Use the Enum instead of string
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    logging.info("Initializing the dataset and dataloader...")

    # Initialize dataset and dataloader using AllClassesDataset
    dataset = AllClassesDataset(
        source=source_dir,
        split=split,
        # output_size=16  # Set to match anomaly_map resolution
    )
    # The per-sample processing below (squeeze(0), .item(), image_path[0])
    # assumes batch_size=1, so the loader is built accordingly.
    # worker_init_fn is a no-op with num_workers=0 but kept for reproducibility
    # if workers are enabled later.
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0,
                            worker_init_fn=worker_init_fn)

    logging.info("Initializing the anomaly detector...")
    # Initialize anomaly detector
    detector = AnomalyDetector(device=device)

    # Initialize DumpScores
    dump_scores = DumpScores(output_dir=output_scores_dir)

    logging.info("Starting anomaly detection inference...")
    # Initialize containers for metrics
    classes = dataset.get_all_class_names()
    metrics_data = {cls: {
        "y_true_image": [],
        "y_scores_image": [],
        "y_true_pixel": [],
        "y_scores_pixel": []
    } for cls in classes}

    # Iterate through the dataset
    for batch_idx, batch in enumerate(dataloader):
        image = batch['image'].squeeze(0)  # Shape: [3, H, W]
        mask = batch['mask'].squeeze(1).numpy()  # Drop the channel dimension: [1, H, W]
        image_label = batch['is_anomaly'].item()  # 1 or 0
        image_path = batch['image_path'][0]  # batch_size=1 (see DataLoader above)

        # Extract class name from image_path
        try:
            class_name = get_class_name(image_path, source_dir)
        except ValueError as e:
            logging.error(f"Error extracting class name: {e}")
            continue  # Skip this sample

        # Extract features and compute scores using GLASS
        image_score, anomaly_map = detector.extract_features(image, "all")

        # Compute pixel-level anomaly score (already normalized)
        pixel_score = detector.compute_pixel_score(anomaly_map).squeeze()

        pixel_score_tensor = torch.from_numpy(pixel_score).float().unsqueeze(0).unsqueeze(0).to(
            device)  # Shape: [1, 1, 17, 17]

        # **Upsample pixel_score to (224, 224)**
        # Option 1: Using PyTorch Interpolation
        pixel_score = F.interpolate(
            pixel_score_tensor,  # batch and channel dimensions added above
            size=(224, 224),
            mode='bilinear',
            align_corners=False
        ).squeeze(0).cpu().numpy()  # Removes the batch dimension, resulting in [1, 224, 224]

        # Option 2: Using OpenCV (Uncomment if preferred; requires `import cv2`)
        # pixel_score = cv2.resize(
        #     pixel_score,
        #     dsize=(224, 224),
        #     interpolation=cv2.INTER_LINEAR
        # )

        # **Optional: Verify the upsampled pixel_score shape**
        # if pixel_score.shape != (1, 224, 224):
        #     logging.warning(
        #         f"Upsampled pixel score shape mismatch for image {image_path}: expected (1, 224, 224), got {pixel_score.shape}")
        #     continue  # Skip this sample

        # Append to metrics_data
        metrics_data[class_name]["y_true_image"].append(image_label)
        metrics_data[class_name]["y_scores_image"].append(image_score)
        metrics_data[class_name]["y_true_pixel"].append(mask)
        metrics_data[class_name]["y_scores_pixel"].append(pixel_score)

        # Save individual image scores
        dump_scores.save_scores([image_path], [image_score], [pixel_score])

        logging.info(f"[{batch_idx + 1}/{len(dataloader)}] Processed image: {image_path}")
        logging.info(f"Image-level score: {image_score:.4f}")
        logging.info(f"Pixel-level mean score: {pixel_score.mean():.4f}")

    logging.info("Anomaly detection inference completed. Computing metrics...")

    # Initialize dictionary to hold metrics per class
    classes_metrics = {}

    for cls in classes:
        y_true_image = np.array(metrics_data[cls]["y_true_image"])
        y_scores_image = np.array(metrics_data[cls]["y_scores_image"])
        y_true_pixel = np.array(metrics_data[cls]["y_true_pixel"])
        y_scores_pixel = np.array(metrics_data[cls]["y_scores_pixel"])

        # Check if there are any samples for the class
        if len(y_true_image) == 0:
            logging.warning(f"No samples found for class {cls}. Skipping metric computation.")
            continue

        try:
            metrics = compute_metrics(y_true_image, y_scores_image, y_true_pixel, y_scores_pixel)
            classes_metrics[cls] = metrics
            logging.info(f"Metrics computed for class: {cls}")
        except Exception as e:
            logging.error(f"Failed to compute metrics for class {cls}: {e}")

    # Save metrics to JSON
    os.makedirs(output_scores_dir, exist_ok=True)
    metrics_json_path = os.path.join(output_scores_dir, "metrics.json")
    try:
        with open(metrics_json_path, "w") as f:
            json.dump(classes_metrics, f, indent=4)
        logging.info(f"Metrics successfully saved to {metrics_json_path}")
    except Exception as e:
        logging.error(f"Failed to save metrics to {metrics_json_path}: {e}")


if __name__ == "__main__":
    main()
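For reference, a short sketch of reading back the per-class metrics that main() writes; the file name matches the code above and the key layout follows compute_metrics():

# Read the metrics.json produced by main() and print the leaderboard scores.
import json

with open("./output_scores/metrics.json") as f:
    classes_metrics = json.load(f)

for cls, m in classes_metrics.items():
    print(cls, m["overall_metric"]["leaderboard_score"])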
models/anomaly_detector.py
ADDED
@@ -0,0 +1,186 @@
# models/anomaly_detector.py

import torch
import torch.nn.functional as F
import numpy as np
from .glass import GLASS  # Ensure correct import
import os
import logging
from torchvision import models

LOGGER = logging.getLogger(__name__)


class AnomalyDetector:
    def __init__(self, device='cuda'):
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

        # Initialize the backbone (e.g., ResNet-50) without pretrained weights
        backbone = models.resnet50(weights=None)

        # Load backbone weights from local file
        backbone_weights_path = './backbones/resnet50_backbone.pth'  # Update this path as needed
        if os.path.exists(backbone_weights_path):
            LOGGER.info(f"Loading ResNet-50 backbone weights from '{backbone_weights_path}'")
            checkpoint = torch.load(backbone_weights_path, map_location="cpu")
            try:
                backbone.load_state_dict(checkpoint, strict=True)
                LOGGER.info("ResNet-50 backbone weights loaded successfully.")
            except RuntimeError as e:
                LOGGER.error(f"Error loading ResNet-50 backbone state_dict: {e}")
                raise
        else:
            LOGGER.error(f"Backbone weights not found at '{backbone_weights_path}'")
            raise FileNotFoundError(f"Backbone weights not found at '{backbone_weights_path}'")

        # Initialize the GLASS model
        self.glass = GLASS(device=self.device)

        # Define parameters for GLASS.load() to match training
        layers_to_extract_from = ['layer4']  # Extract only the last layer
        input_shape = (3, 224, 224)  # Match training input shape
        pretrain_embed_dimension = 2048  # Feature dimension of 'layer4' in ResNet-50
        target_embed_dimension = 1024  # Match training target dimension

        # Initialize GLASS with consistent parameters
        self.glass.load(
            backbone=backbone,
            layers_to_extract_from=layers_to_extract_from,
            device=self.device,
            input_shape=input_shape,
            pretrain_embed_dimension=pretrain_embed_dimension,
            target_embed_dimension=target_embed_dimension,
            patchsize=3,
            patchstride=1,
            meta_epochs=640,  # Not relevant for inference but required by load method
            eval_epochs=1,
            dsc_layers=2,
            dsc_hidden=1024,
            dsc_margin=0.5,
            train_backbone=False,
            pre_proj=1,
            mining=1,
            noise=0.015,
            radius=0.75,
            p=0.5,
            lr=0.0001,
            svd=0,
            step=20,
            limit=392,
        )

        # Set model directories
        model_dir = "./models"  # Base directory for models
        dataset_name = "rayan_dataset"  # Example dataset name
        self.glass.set_model_dir(model_dir, dataset_name)

        self.glass.to(self.device)
        self.glass.eval()  # Set GLASS to evaluation mode

        # Initialize a cache to keep track of loaded classes
        self.loaded_classes = set()

    def load_model_weights(self, model_dir, classname):
        """
        Load the saved model weights for a specific class.

        Args:
            model_dir (str): Base directory where models are saved.
            classname (str): The class name whose model weights to load.
        """
        checkpoint_path = os.path.join(model_dir, classname, f"best_model_{classname}.pth")
        if os.path.exists(checkpoint_path):
            LOGGER.info(f"Loading model weights from '{checkpoint_path}' for class '{classname}'")
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            try:
                self.glass.load_state_dict(checkpoint, strict=True)
                LOGGER.info(f"Model weights loaded successfully for class '{classname}'")
            except RuntimeError as e:
                LOGGER.error(f"Error loading state_dict for class '{classname}': {e}")
                raise
        else:
            LOGGER.error(f"Checkpoint not found at '{checkpoint_path}' for class '{classname}'")
            raise FileNotFoundError(f"Checkpoint not found at '{checkpoint_path}' for class '{classname}'")

    def extract_features(self, image, classname):
        """
        Use GLASS to extract features and generate anomaly scores for a specific class.

        Args:
            image (torch.Tensor): Image tensor of shape [3, H, W]
            classname (str): The class name for which to perform anomaly detection.

        Returns:
            tuple: (image_score, anomaly_map)
        """

        # Load model weights for classname if not already loaded
        # if classname not in self.loaded_classes:
        #     try:
        #         self.load_model_weights(model_dir="./models", classname=classname)
        #         self.loaded_classes.add(classname)
        #     except FileNotFoundError as e:
        #         LOGGER.error(f"Failed to load model weights for class '{classname}': {e}")
        #         raise

        # Reshape image to include batch dimension
        image = image.unsqueeze(0).to(self.device)  # Shape: [1, 3, H, W]

        # Use GLASS to get embeddings
        with torch.no_grad():
            patch_features, patch_shapes = self.glass._embed(image, evaluation=True)
            if self.glass.pre_proj > 0:
                patch_features = self.glass.pre_projection(patch_features)
                # Handle if pre_projection returns multiple outputs
                if isinstance(patch_features, (tuple, list)):
                    patch_features = patch_features[0]

            # Pass through discriminator to get anomaly scores
            patch_scores = self.glass.discriminator(patch_features)
            patch_scores = self.glass.patch_maker.unpatch_scores(patch_scores, batchsize=image.shape[0])

        # Select the last layer's patch_shapes (only one layer now)
        last_patch_shape = patch_shapes[-1]  # Should be [17, 17]

        # Ensure that last_patch_shape is a list or tuple of two integers
        if isinstance(last_patch_shape, (list, tuple)) and len(last_patch_shape) == 2:
            # Reshape patch_scores to [batch_size, H_patches, W_patches]
            # First, squeeze the last dimension
            patch_scores = patch_scores.squeeze(-1)  # Shape: [1, 289]

            # Reshape to [1, 17, 17]
            patch_scores = patch_scores.reshape(image.shape[0], *last_patch_shape)  # [1, 17, 17]
        else:
            LOGGER.error(f"Unexpected patch_shapes format: {patch_shapes}")
            raise ValueError(f"Unexpected patch_shapes format: {patch_shapes}")

        # Compute image-level score (example: mean of patch scores)
        image_score = patch_scores.mean().item()

        # Anomaly map is the patch_scores itself, clipped to [0, 1]
        anomaly_map = patch_scores.cpu().numpy()
        anomaly_map = np.clip(anomaly_map, 0, 1)

        # Log anomaly map statistics for debugging
        LOGGER.info(f"Anomaly map stats for class '{classname}': min={anomaly_map.min():.4f}, max={anomaly_map.max():.4f}, mean={anomaly_map.mean():.4f}")

        return image_score, anomaly_map

    def compute_pixel_score(self, anomaly_map):
        """
        Processes the anomaly map for pixel-level evaluation.

        Args:
            anomaly_map (np.ndarray): Anomaly map of shape [1, 17, 17]

        Returns:
            np.ndarray: Anomaly map of the same shape, min-max normalized to [0, 1]
        """
        # Normalize anomaly_map to [0, 1]
        min_val = anomaly_map.min()
        max_val = anomaly_map.max()
        if max_val - min_val < 1e-8:
            LOGGER.warning("Anomaly map has zero variance. Returning zero map.")
            return np.zeros_like(anomaly_map)
        anomaly_map = (anomaly_map - min_val) / (max_val - min_val + 1e-8)
        return anomaly_map
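A single-image inference sketch; it assumes the backbone checkpoint described in __init__ actually exists at ./backbones/resnet50_backbone.pth (otherwise construction raises FileNotFoundError), and the input tensor is made up:

# Hedged usage sketch for AnomalyDetector on a dummy input.
import torch
from models.anomaly_detector import AnomalyDetector

detector = AnomalyDetector(device="cuda:0")  # falls back to CPU if CUDA is absent
image = torch.rand(3, 224, 224)              # dummy, already-normalized image
image_score, anomaly_map = detector.extract_features(image, "all")
print(image_score, anomaly_map.shape)        # scalar score and a [1, H_p, W_p] map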
models/common.py
ADDED
@@ -0,0 +1,154 @@
# common.py

import copy
import numpy as np
import scipy.ndimage as ndimage
import torch
import torch.nn.functional as F
from torch import nn


class Preprocessing(torch.nn.Module):
    def __init__(self, input_dims, output_dim):
        super(Preprocessing, self).__init__()
        self.input_dims = input_dims
        self.output_dim = output_dim

        self.preprocessing_modules = torch.nn.ModuleList()
        for _ in input_dims:
            module = MeanMapper(output_dim)
            self.preprocessing_modules.append(module)

    def forward(self, features):
        _features = []
        for module, feature in zip(self.preprocessing_modules, features):
            _features.append(module(feature))
        return torch.stack(_features, dim=1)


class MeanMapper(torch.nn.Module):
    def __init__(self, preprocessing_dim):
        super(MeanMapper, self).__init__()
        self.preprocessing_dim = preprocessing_dim

    def forward(self, features):
        features = features.reshape(len(features), 1, -1)
        return F.adaptive_avg_pool1d(features, self.preprocessing_dim).squeeze(1)


class Aggregator(torch.nn.Module):
    def __init__(self, target_dim):
        super(Aggregator, self).__init__()
        self.target_dim = target_dim

    def forward(self, features):
        """Returns reshaped and average pooled features."""
        features = features.reshape(len(features), 1, -1)
        features = F.adaptive_avg_pool1d(features, self.target_dim)
        return features.reshape(len(features), -1)


class RescaleSegmentor:
    def __init__(self, device, target_size=288):
        self.device = device
        self.target_size = target_size
        self.smoothing = 4

    def convert_to_segmentation(self, patch_scores):
        with torch.no_grad():
            if isinstance(patch_scores, np.ndarray):
                patch_scores = torch.from_numpy(patch_scores)
            _scores = patch_scores.to(self.device)
            _scores = _scores.unsqueeze(1)
            _scores = F.interpolate(
                _scores, size=self.target_size, mode="bilinear", align_corners=False
            )
            _scores = _scores.squeeze(1)
            patch_scores = _scores.cpu().numpy()
        return [ndimage.gaussian_filter(patch_score, sigma=self.smoothing) for patch_score in patch_scores]


class NetworkFeatureAggregator(torch.nn.Module):
    """Efficient extraction of network features.

    Runs a network only to the last layer of the list of layers where
    network features should be extracted from.

    Args:
        backbone: torchvision.model
        layers_to_extract_from: [list of str]
    """

    def __init__(self, backbone, layers_to_extract_from, device, train_backbone=False):
        super(NetworkFeatureAggregator, self).__init__()
        self.layers_to_extract_from = layers_to_extract_from
        self.backbone = backbone
        self.device = device
        self.train_backbone = train_backbone
        if not hasattr(backbone, "hook_handles"):
            self.backbone.hook_handles = []
        for handle in self.backbone.hook_handles:
            handle.remove()
        self.outputs = {}

        for extract_layer in layers_to_extract_from:
            self.register_hook(extract_layer)

        self.to(self.device)

    def forward(self, images, eval=True):
        self.outputs.clear()
        if self.train_backbone and not eval:
            self.backbone.train()
            self.backbone(images)
        else:
            self.backbone.eval()
            with torch.no_grad():
                self.backbone(images)
        return self.outputs

    def feature_dimensions(self, input_shape):
        """Computes the feature dimensions for all layers given input_shape."""
        _input = torch.ones([1] + list(input_shape)).to(self.device)
        _output = self(_input)
        return [_output[layer].shape[1] for layer in self.layers_to_extract_from]

    def register_hook(self, layer_name):
        module = self.find_module(self.backbone, layer_name)
        if module is not None:
            forward_hook = ForwardHook(self.outputs, layer_name, self.layers_to_extract_from[-1])
            if isinstance(module, torch.nn.Sequential):
                hook = module[-1].register_forward_hook(forward_hook)
            else:
                hook = module.register_forward_hook(forward_hook)
            self.backbone.hook_handles.append(hook)
        else:
            raise ValueError(f"Module {layer_name} not found in the model")

    def find_module(self, model, module_name):
        for name, module in model.named_modules():
            if name == module_name:
                return module
            elif '.' in module_name:
                father, child = module_name.split('.', 1)
                if name == father:
                    return self.find_module(module, child)
        return None


class ForwardHook:
    def __init__(self, hook_dict, layer_name: str, last_layer_to_extract: str):
        self.hook_dict = hook_dict
        self.layer_name = layer_name
        self.raise_exception_to_break = copy.deepcopy(
            layer_name == last_layer_to_extract
        )

    def __call__(self, module, input, output):
        self.hook_dict[self.layer_name] = output
        return None


class LastLayerToExtractReachedException(Exception):
    pass
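As a quick sanity check of the hook-based extraction, a self-contained sketch on a randomly initialized ResNet-50 (no weights are downloaded; assumes the repository root is on PYTHONPATH):

# layer4 of ResNet-50 exposes 2048 channels to the aggregator.
import torch
from torchvision import models
from models.common import NetworkFeatureAggregator

backbone = models.resnet50(weights=None)
aggregator = NetworkFeatureAggregator(backbone, ["layer4"], device=torch.device("cpu"))
print(aggregator.feature_dimensions((3, 224, 224)))  # expected: [2048]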
models/glass.py
ADDED
@@ -0,0 +1,372 @@
# models/glass.py

import logging
import math
import os
import torch
import numpy as np
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast
from .common import NetworkFeatureAggregator, Preprocessing, MeanMapper, Aggregator, RescaleSegmentor, ForwardHook
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import torch.optim as optim
from sklearn.metrics import roc_auc_score
from .model import Discriminator, Projection, PatchMaker

LOGGER = logging.getLogger(__name__)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


class TBWrapper:
    def __init__(self, log_dir):
        self.g_iter = 0
        self.logger = SummaryWriter(log_dir=log_dir)

    def step(self):
        self.g_iter += 1

    def log(self, tag, value, step):
        self.logger.add_scalar(tag, value, step)

    def close(self):
        # Needed because trainer() calls tb_writer.close() at the end of training.
        self.logger.close()


class GLASS(torch.nn.Module):
    def __init__(self, device):
        super(GLASS, self).__init__()
        self.device = device

    def load(
        self,
        backbone,
        layers_to_extract_from,
        device,
        input_shape,
        pretrain_embed_dimension,
        target_embed_dimension,
        patchsize=3,
        patchstride=1,
        meta_epochs=640,
        eval_epochs=1,
        dsc_layers=2,
        dsc_hidden=1024,
        dsc_margin=0.5,
        train_backbone=False,  # Set externally
        pre_proj=1,
        mining=1,
        noise=0.015,
        radius=0.75,
        p=0.5,
        lr=0.0001,
        svd=0,
        step=20,
        limit=392,
        **kwargs,
    ):

        self.backbone = backbone.to(device)
        self.layers_to_extract_from = layers_to_extract_from
        self.input_shape = input_shape
        self.device = device

        self.forward_modules = torch.nn.ModuleDict({})
        feature_aggregator = NetworkFeatureAggregator(
            self.backbone, self.layers_to_extract_from, self.device, train_backbone
        )
        feature_dimensions = feature_aggregator.feature_dimensions(input_shape)
        self.forward_modules["feature_aggregator"] = feature_aggregator

        preprocessing = Preprocessing(feature_dimensions, pretrain_embed_dimension)
        self.forward_modules["preprocessing"] = preprocessing
        self.target_embed_dimension = target_embed_dimension
        preadapt_aggregator = Aggregator(target_dim=target_embed_dimension)
        preadapt_aggregator.to(self.device)
        self.forward_modules["preadapt_aggregator"] = preadapt_aggregator

        self.meta_epochs = meta_epochs
        self.lr = lr
        self.train_backbone = train_backbone
        if self.train_backbone:
            self.backbone_opt = torch.optim.AdamW(self.forward_modules["feature_aggregator"].backbone.parameters(), lr)

        self.pre_proj = pre_proj
        if self.pre_proj > 0:
            self.pre_projection = Projection(self.target_embed_dimension, self.target_embed_dimension, pre_proj)
            self.pre_projection.to(self.device)
            self.proj_opt = torch.optim.Adam(self.pre_projection.parameters(), lr, weight_decay=1e-5)

        self.eval_epochs = eval_epochs
        self.dsc_layers = dsc_layers
        self.dsc_hidden = dsc_hidden
        self.discriminator = Discriminator(self.target_embed_dimension, n_layers=dsc_layers, hidden=dsc_hidden)
        self.discriminator.to(self.device)
        self.dsc_opt = torch.optim.AdamW(self.discriminator.parameters(), lr=lr * 2)
        self.dsc_margin = dsc_margin

        self.c = torch.tensor(0)
        self.c_ = torch.tensor(0)
        self.p = p
        self.radius = radius
        self.mining = mining
        self.noise = noise
        self.svd = svd
        self.step = step
        self.limit = limit
        self.distribution = 0

        # Replace FocalLoss with MSELoss
        self.loss_fn = nn.MSELoss()

        self.patch_maker = PatchMaker(patchsize, stride=patchstride)
        self.anomaly_segmentor = RescaleSegmentor(device=self.device, target_size=input_shape[-2:])
        self.model_dir = ""
        self.dataset_name = ""
        self.logger = None

    def set_model_dir(self, model_dir, dataset_name):
        self.model_dir = model_dir
        os.makedirs(self.model_dir, exist_ok=True)
        self.ckpt_dir = os.path.join(self.model_dir, dataset_name)
        os.makedirs(self.ckpt_dir, exist_ok=True)
        self.tb_dir = os.path.join(self.ckpt_dir, "tb")
        os.makedirs(self.tb_dir, exist_ok=True)
        self.logger = TBWrapper(self.tb_dir)

    def _embed(self, images, detach=True, provide_patch_shapes=False, evaluation=False):
        """Returns feature embeddings for images."""
        images = images.float()  # Ensure input tensor is float32
        if not evaluation and self.train_backbone:
            self.forward_modules["feature_aggregator"].train()
            features = self.forward_modules["feature_aggregator"](images, eval=evaluation)
        else:
            self.forward_modules["feature_aggregator"].eval()
            with torch.no_grad():
                features = self.forward_modules["feature_aggregator"](images)

        features = [features[layer] for layer in self.layers_to_extract_from]

        for i, feat in enumerate(features):
            if len(feat.shape) == 3:
                B, L, C = feat.shape
                sqrt_L = int(math.sqrt(L))
                if sqrt_L * sqrt_L != L:
                    raise ValueError(f"Layer {self.layers_to_extract_from[i]} output has non-square spatial dimensions: {feat.shape}")
                features[i] = feat.reshape(B, sqrt_L, sqrt_L, C).permute(0, 3, 1, 2)
                # Debug statement (only meaningful on the trainable-backbone path)
                assert features[i].requires_grad, f"Feature {i} from layer {self.layers_to_extract_from[i]} does not require grad."

        features = [self.patch_maker.patchify(x, return_spatial_info=True) for x in features]
        patch_shapes = [x[1] for x in features]
        patch_features = [x[0] for x in features]
        ref_num_patches = patch_shapes[0]

        for i in range(1, len(patch_features)):
            _features = patch_features[i]
            patch_dims = patch_shapes[i]

            _features = _features.reshape(
                _features.shape[0], patch_dims[0], patch_dims[1], *_features.shape[2:]
            )
            _features = _features.permute(0, 3, 4, 5, 1, 2)
            perm_base_shape = _features.shape
            _features = _features.reshape(-1, *_features.shape[-2:])
            _features = F.interpolate(
                _features.unsqueeze(1),
                size=(ref_num_patches[0], ref_num_patches[1]),
                mode="bilinear",
                align_corners=False,
            )
            _features = _features.squeeze(1)
            _features = _features.reshape(
                *perm_base_shape[:-2], ref_num_patches[0], ref_num_patches[1]
            )
            _features = _features.permute(0, 4, 5, 1, 2, 3)
            _features = _features.reshape(len(_features), -1, *_features.shape[-3:])
            patch_features[i] = _features

        patch_features = [x.reshape(-1, *x.shape[-3:]) for x in patch_features]
        patch_features = self.forward_modules["preprocessing"](patch_features)
        patch_features = self.forward_modules["preadapt_aggregator"](patch_features)

        return patch_features, patch_shapes

    def trainer(self, training_data, val_data, name):
        """
        Training loop for the GLASS model.

        Args:
            training_data (DataLoader): DataLoader for the training dataset.
            val_data (DataLoader): DataLoader for the validation dataset.
            name (str): Name identifier for the training run.
        """
        self.train()
        self.discriminator.train()

        # Initialize optimizers
        optimizer = optim.AdamW(self.forward_modules.parameters(), lr=self.lr)
        optimizer_d = optim.AdamW(self.discriminator.parameters(), lr=self.lr * 2)

        # Initialize loss functions.
        # Note: the Discriminator already ends in Sigmoid; BCEWithLogitsLoss is kept
        # here because plain BCELoss is not autocast-safe, at the cost of squashing
        # the outputs twice.
        criterion_d = nn.BCEWithLogitsLoss()

        # Initialize separate AMP scalers
        scaler_main = GradScaler()
        scaler_dsc = GradScaler()

        # Initialize the TensorBoard writer; fall back to a default-location wrapper
        # so tb_writer always exposes .log() and .close()
        if self.logger is not None:
            tb_writer = self.logger
        else:
            tb_writer = TBWrapper(log_dir=None)

        best_auroc = 0.0
        best_model_path = os.path.join(self.model_dir, f"best_model_{name}.pth")

        for epoch in range(1, self.meta_epochs + 1):
            LOGGER.info(f"Epoch [{epoch}/{self.meta_epochs}]")
            epoch_loss = 0.0
            epoch_loss_d = 0.0
            for batch_idx, batch in enumerate(training_data):
                images = batch['image'].to(self.device).float()  # [B, 3, H, W]
                aug_images = batch['aug'].to(self.device).float()  # [B, 3, H, W]
                masks_s = batch['mask_s'].to(self.device).float()  # [B, H, W]
                masks_gt = batch['mask_gt'].to(self.device).float()  # [B, 1, H, W]

                optimizer.zero_grad()
                optimizer_d.zero_grad()

                # ----- Train Main Model -----
                with autocast():
                    # Forward pass
                    embeddings, _ = self._embed(images)  # [B*N_patches, D]
                    aug_embeddings, _ = self._embed(aug_images)  # [B*N_patches, D]

                    # Aggregate embeddings to [B, D] by averaging over patches
                    B = images.size(0)
                    N_patches = embeddings.size(0) // B
                    assert embeddings.size(
                        0) == B * N_patches, "Embeddings cannot be evenly divided into the batch size."
                    embeddings = embeddings.view(B, N_patches, -1).mean(dim=1)  # [B, D]
                    aug_embeddings = aug_embeddings.view(B, N_patches, -1).mean(dim=1)  # [B, D]

                    # Debug tensor properties
                    assert embeddings.requires_grad, "Embeddings do not require grad!"
                    assert aug_embeddings.requires_grad, "Augmented embeddings do not require grad!"
                    assert embeddings.shape[0] == images.size(
                        0), "Aggregated embeddings batch size does not match input batch size."

                    # Compute reconstruction or similarity loss
                    loss = self.loss_fn(embeddings, aug_embeddings)

                # Backward pass with AMP for main model
                scaler_main.scale(loss).backward()
                scaler_main.step(optimizer)
                scaler_main.update()

                epoch_loss += loss.item()

                # ----- Train Discriminator -----
                with autocast():
                    # Detach embeddings to prevent gradients flowing back to the main model
                    embeddings_detached = embeddings.detach()
                    aug_embeddings_detached = aug_embeddings.detach()

                    # Discriminator forward pass
                    outputs_real = self.discriminator(embeddings_detached)  # [B, 1]
                    outputs_fake = self.discriminator(aug_embeddings_detached)  # [B, 1]

                    # Create labels
                    real_labels = torch.ones(outputs_real.size(0), 1).to(self.device)  # [B, 1]
                    fake_labels = torch.zeros(outputs_fake.size(0), 1).to(self.device)  # [B, 1]

                    # Compute discriminator loss
                    loss_real = criterion_d(outputs_real, real_labels)
                    loss_fake = criterion_d(outputs_fake, fake_labels)
                    loss_d = loss_real + loss_fake

                # Backward pass with AMP for discriminator
                scaler_dsc.scale(loss_d).backward()
                scaler_dsc.step(optimizer_d)
                scaler_dsc.update()

                epoch_loss_d += loss_d.item()

                if batch_idx % 100 == 0:
                    LOGGER.info(f"Batch [{batch_idx}/{len(training_data)}] "
                                f"Loss: {loss.item():.4f} Loss_D: {loss_d.item():.4f}")

            avg_epoch_loss = epoch_loss / len(training_data)
            avg_epoch_loss_d = epoch_loss_d / len(training_data)

            LOGGER.info(f"Epoch [{epoch}/{self.meta_epochs}] "
                        f"Average Loss: {avg_epoch_loss:.4f} "
                        f"Average Loss_D: {avg_epoch_loss_d:.4f}")

            # Log to TensorBoard
            tb_writer.log("Train/Loss", avg_epoch_loss, epoch)
            tb_writer.log("Train/Loss_D", avg_epoch_loss_d, epoch)

            # Validation
            if epoch % self.eval_epochs == 0:
                auroc = self.tester(val_data, name)
                LOGGER.info(f"Validation AUROC after Epoch [{epoch}]: {auroc:.4f}")
                tb_writer.log("Validation/AUROC", auroc, epoch)

                # Save the best model
                if auroc > best_auroc:
                    best_auroc = auroc
                    torch.save(self.state_dict(), best_model_path)  # Save only state_dict
                    LOGGER.info(f"Best model saved at Epoch [{epoch}] with AUROC: {auroc:.4f}")

        LOGGER.info(f"Training completed. Best AUROC: {best_auroc:.4f}")
        tb_writer.close()

    def tester(self, test_data, name):
        """
        Evaluation loop for the GLASS model.

        Args:
            test_data (DataLoader): DataLoader for the test dataset.
            name (str): Name identifier for the evaluation run.

        Returns:
            float: AUROC score on the test dataset.
        """
        self.eval()
        self.discriminator.eval()
        all_scores = []
        all_labels = []

        with torch.no_grad():
            for batch_idx, batch in enumerate(test_data):
                images = batch['image'].to(self.device).float()  # [B, 3, H, W]
                masks_gt = batch['mask_gt'].to(self.device).float()  # [B, 1, H, W]
                labels = batch['is_anomaly'].cpu().numpy()  # [B]

                # Forward pass
                embeddings, _ = self._embed(images, evaluation=True)  # [B*N_patches, D]
                B = images.size(0)
                N_patches = embeddings.size(0) // B
                embeddings = embeddings.view(B, N_patches, -1).mean(dim=1)  # [B, D]
                anomaly_scores = self.discriminator(embeddings).cpu().numpy().flatten()  # [B]

                all_scores.extend(anomaly_scores.tolist())
                all_labels.extend(labels.tolist())

        # Compute AUROC
        auroc = roc_auc_score(all_labels, all_scores)
        return auroc

    def _evaluate(self, images, scores, segmentations, labels_gt, masks_gt, name, path='training'):
        # Implementation of evaluation metrics
        pass

    def predict(self, test_dataloader):
        """This function provides anomaly scores/maps for full dataloaders."""
        # Implementation of prediction logic
        pass

    def _predict(self, img):
        """Infer score and mask for a batch of images."""
        # Implementation of individual prediction logic
        pass
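To make the patch-to-image aggregation used in trainer() and tester() concrete, a toy illustration with made-up sizes:

# _embed() returns [B*N_patches, D]; averaging over patches gives one embedding per image.
import torch

B, N_patches, D = 2, 49, 1024
patch_embeddings = torch.randn(B * N_patches, D)
image_embeddings = patch_embeddings.view(B, N_patches, -1).mean(dim=1)
print(image_embeddings.shape)  # torch.Size([2, 1024])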
models/model.py
ADDED
@@ -0,0 +1,101 @@
# models/model.py

import torch


def init_weight(m):
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.xavier_normal_(m.weight)
    if isinstance(m, torch.nn.BatchNorm2d):
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
    elif isinstance(m, torch.nn.Conv2d):
        m.weight.data.normal_(0.0, 0.02)


class Discriminator(torch.nn.Module):
    def __init__(self, in_planes, n_layers=2, hidden=None):
        super(Discriminator, self).__init__()

        _hidden = in_planes if hidden is None else hidden
        self.body = torch.nn.Sequential()
        for i in range(n_layers - 1):
            _in = in_planes if i == 0 else _hidden
            _hidden = int(_hidden // 1.5) if hidden is None else hidden
            self.body.add_module('block%d' % (i + 1),
                                 torch.nn.Sequential(
                                     torch.nn.Linear(_in, _hidden),
                                     torch.nn.BatchNorm1d(_hidden),
                                     torch.nn.LeakyReLU(0.2)
                                 ))
        self.tail = torch.nn.Sequential(
            torch.nn.Linear(_hidden, 1, bias=False),
            torch.nn.Sigmoid()
        )
        self.apply(init_weight)

    def forward(self, x):
        x = self.body(x)
        x = self.tail(x)
        return x


class Projection(torch.nn.Module):
    def __init__(self, in_planes, out_planes=None, n_layers=1, layer_type=0):
        super(Projection, self).__init__()

        if out_planes is None:
            out_planes = in_planes
        self.layers = torch.nn.Sequential()
        _in = None
        _out = None
        for i in range(n_layers):
            _in = in_planes if i == 0 else _out
            _out = out_planes
            self.layers.add_module(f"{i}fc", torch.nn.Linear(_in, _out))
            if i < n_layers - 1:
                if layer_type > 1:
                    self.layers.add_module(f"{i}relu", torch.nn.LeakyReLU(.2))
        self.apply(init_weight)

    def forward(self, x):
        x = self.layers(x)
        return x


class PatchMaker:
    def __init__(self, patchsize, top_k=0, stride=None):
        self.patchsize = patchsize
        self.stride = stride
        self.top_k = top_k

    def patchify(self, features, return_spatial_info=False):
        """Convert a tensor into a tensor of respective patches.

        Args:
            features: [torch.Tensor, bs x c x h x w]
        Returns:
            [torch.Tensor, bs * (h//stride) * (w//stride), c, patchsize, patchsize]
        """
        padding = int((self.patchsize - 1) / 2)
        unfolder = torch.nn.Unfold(kernel_size=self.patchsize, stride=self.stride, padding=padding, dilation=1)
        unfolded_features = unfolder(features)
        number_of_total_patches = []
        for s in features.shape[-2:]:
            n_patches = (s + 2 * padding - 1 * (self.patchsize - 1) - 1) / self.stride + 1
            number_of_total_patches.append(int(n_patches))
        unfolded_features = unfolded_features.reshape(
            *features.shape[:2], self.patchsize, self.patchsize, -1
        )
        unfolded_features = unfolded_features.permute(0, 4, 1, 2, 3)

        if return_spatial_info:
            return unfolded_features, number_of_total_patches
        return unfolded_features

    def unpatch_scores(self, x, batchsize):
        return x.reshape(batchsize, -1, *x.shape[1:])

    def score(self, x):
        x = x[:, :, 0]
        x = torch.max(x, dim=1).values
        return x
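The spatial size reported by patchify follows the standard unfold formula n = (s + 2*padding - (patchsize - 1) - 1) / stride + 1; a toy check with the defaults used above:

# With patchsize=3, stride=1 and same-padding, a 7x7 map yields 7*7 = 49 patches.
import torch
from models.model import PatchMaker

pm = PatchMaker(patchsize=3, stride=1)
features = torch.randn(1, 2048, 7, 7)  # e.g., a ResNet-50 layer4 output
patches, spatial = pm.patchify(features, return_spatial_info=True)
print(patches.shape, spatial)  # torch.Size([1, 49, 2048, 3, 3]) [7, 7]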
requirements.txt
ADDED
@@ -0,0 +1,8 @@
torch==2.4.1
torchvision==0.19.1  # torchvision release paired with torch 2.4.1 (0.16.1 pins torch 2.1.1)
numpy==1.23.5
Pillow==9.4.0
tqdm==4.65.0
scikit-image==0.20.0
scikit-learn==1.2.2
scipy==1.11.4
run.sh
ADDED
@@ -0,0 +1,2 @@
#!/bin/sh
python3 main.py
runner.py
ADDED
@@ -0,0 +1,37 @@
# -----------------------------------------------------------------------------
# This Python script is the primary entry point called by our judge. It runs
# your code to generate anomaly scores, then evaluates those scores to produce
# the final results.
# -----------------------------------------------------------------------------

import subprocess

# Step 1: Generate anomaly scores
subprocess.run(["./run.sh"], check=True)

# Step 2: Evaluate the generated scores
subprocess.run(
    [
        "python3",
        "evaluation/eval_main.py",
        "--device",
        "0",
        "--data_path",
        "./data/",
        "--dataset_name",
        "rayan_dataset",
        "--class_name",
        "all",
        "--output_dir",
        "./output",
        "--output_scores_dir",
        "./output_scores",
        "--save_csv",
        "True",
        "--save_json",
        "True",
        "--class_name_mapping_dir",
        "./evaluation/class_name_mapping.json",
    ],
    check=True,
)
utils/dump_scores.py
ADDED
@@ -0,0 +1,34 @@
# utils/dump_scores.py

import os
import json
from pathlib import Path


class DumpScores:
    def __init__(self, output_dir):
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)

    def save_scores(self, image_paths, img_level_scores, pix_level_scores):
        for img_path, img_score, pix_score in zip(image_paths, img_level_scores, pix_level_scores):
            # Determine the relative path to maintain directory structure
            relative_path = os.path.relpath(img_path, "./data")
            relative_dir = os.path.dirname(relative_path)
            output_dir = os.path.join(self.output_dir, relative_dir)
            os.makedirs(output_dir, exist_ok=True)

            # Get the image filename without extension
            img_name = Path(img_path).stem

            # Create the JSON structure
            score_data = {
                "img_level_score": img_score,
                "pix_level_score": pix_score.tolist()  # Convert numpy array to list for JSON serialization
            }

            # Define the output JSON file path
            json_path = os.path.join(output_dir, f"{img_name}_scores.json")

            # Save the JSON file
            with open(json_path, "w") as f:
                json.dump(score_data, f, indent=4)
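A hedged demo of the mirrored directory layout DumpScores produces; the image path below is made up:

# DumpScores mirrors the ./data directory structure under output_dir.
import numpy as np
from utils.dump_scores import DumpScores

ds = DumpScores(output_dir="./output_scores")
ds.save_scores(
    ["./data/pill/test/broken/image1.png"],  # hypothetical input path
    [0.73],                                  # image-level score
    [np.zeros((1, 224, 224))],               # pixel-level score map
)
# -> writes ./output_scores/pill/test/broken/image1_scores.json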
utils/feature_extractor.py
ADDED
@@ -0,0 +1,18 @@
# utils/feature_extractor.py

import torch
import torch.nn as nn
from torchvision import models


class FeatureExtractor(nn.Module):
    def __init__(self, backbone='resnet50'):
        super(FeatureExtractor, self).__init__()
        if backbone == 'resnet50':
            # ImageNet weights (the modern replacement for pretrained=True)
            self.model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
            # Remove the final average pooling and fully connected layers
            self.features = nn.Sequential(*list(self.model.children())[:-2])
        else:
            raise NotImplementedError(f"Backbone {backbone} is not implemented.")

    def forward(self, x):
        return self.features(x)
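A quick shape check for the extractor (a sketch; the first use downloads the ImageNet weights):

# layer4 output of ResNet-50 for a 224x224 input is [1, 2048, 7, 7].
import torch
from utils.feature_extractor import FeatureExtractor

extractor = FeatureExtractor(backbone='resnet50').eval()
with torch.no_grad():
    feats = extractor(torch.rand(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 2048, 7, 7])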