hasibzunair committed on
Commit
1803579
1 Parent(s): 20deb15
This view is limited to 50 files because it contains too many changes. See raw diff
.DS_Store ADDED
Binary file (6.15 kB). View file
 
README 2.md ADDED
@@ -0,0 +1,169 @@
1
+ # Peekaboo
2
+
3
+ **Concordia University**
4
+
5
+ Hasib Zunair, A. Ben Hamza
6
+
7
+ [[`Paper`](https://arxiv.org/abs/2407.17628)] [[`Project`](https://hasibzunair.github.io/peekaboo/)] [[`Demo`](#4-demo)] [[`BibTeX`](#5-citation)]
8
+
9
+ This is the official code for our **BMVC 2024 paper**:<br>
10
+ [PEEKABOO: Hiding Parts of an Image for Unsupervised Object Localization](https://arxiv.org/abs/2407.17628)
11
+ <br>
12
+
13
+ ![MSL Design](./media/figure.jpg)
14
+
15
+ We aim to explicitly model contextual relationships among pixels through image masking for unsupervised object localization. In a self-supervised procedure (i.e., a pretext task) without any additional training (i.e., no downstream task), context-based representation learning is done both at the pixel level, by making predictions on masked images, and at the shape level, by matching the predictions for the masked input to those for the unmasked one.
16
+
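+ The sketch below illustrates this objective. It is a minimal, illustrative sketch rather than the repository's training loop (see `train.py` for the real implementation): `model`, `image`, `scribble_mask`, and `pseudo_mask` are placeholder names, and the pseudo mask is assumed to come from a self-supervised source (e.g., frozen DINO features), not from human labels.
+
+ ```python
+ # Illustrative sketch of the Peekaboo objective described above (placeholder names).
+ import torch
+ import torch.nn.functional as F
+
+ def peekaboo_losses(model, image, scribble_mask, pseudo_mask):
+     # Pixel-level: predict the object mask from the *masked* image
+     masked_image = image * scribble_mask                  # hide parts of the image
+     pred_masked = model(masked_image)                     # logits, shape [B, 1, H, W]
+     pixel_loss = F.binary_cross_entropy_with_logits(pred_masked, pseudo_mask)
+
+     # Shape-level: the masked prediction should match the unmasked prediction
+     with torch.no_grad():
+         pred_full = model(image)                          # prediction on the unmasked image
+     shape_loss = F.binary_cross_entropy_with_logits(pred_masked, torch.sigmoid(pred_full))
+
+     return pixel_loss + shape_loss
+ ```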
17
+ ## 1. Specification of dependencies
18
+
19
+ This code requires Python 3.8 and CUDA 11.2. Clone the project repository, then create and activate the following conda environment.
20
+
21
+ ```bash
22
+ # clone repo
23
+ git clone https://github.com/hasibzunair/peekaboo
24
+ cd peekaboo
25
+ # create env
26
+ conda update conda
27
+ conda env create -f environment.yml
28
+ conda activate peekaboo
29
+ ```
30
+
31
+ Alternatively, you can create a fresh environment and install the project requirements inside it:
32
+
33
+ ```bash
34
+ # clone repo
35
+ git clone https://github.com/hasibzunair/peekaboo
36
+ cd peekaboo
37
+ # create fresh env
38
+ conda create -n peekaboo python=3.8
39
+ conda activate peekaboo
40
+ # example of pytorch installation
41
+ pip install torch==1.8.1 torchvision==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
42
+ pip install pycocotools
43
+ # install dependencies
44
+ pip install -r requirements.txt
45
+ ```
46
+
47
+ Then, install [DINO](https://arxiv.org/pdf/2104.14294.pdf) using the following commands:
48
+
49
+ ```bash
50
+ git clone https://github.com/facebookresearch/dino.git
51
+ cd dino;
52
+ touch __init__.py
53
+ echo -e "import sys\nfrom os.path import dirname, join\nsys.path.insert(0, join(dirname(__file__), '.'))" >> __init__.py; cd ../;
54
+ ```
55
+
56
+ ## 2a. Training code
57
+
58
+ ### Dataset details
59
+
60
+ We train Peekaboo using only the images of the [DUTS-TR](http://saliencydetection.net/duts/) dataset, without any labels, since Peekaboo is self-supervised. Download it, create a directory named `datasets_local` inside the project folder, and place the dataset there.
61
+
62
+ We evaluate on two tasks: unsupervised saliency detection and single object discovery. Since our method is used in an unsupervised setting, it does not require training or fine-tuning on the datasets we evaluate on.
63
+
64
+ #### Unsupervised Saliency Detection
65
+
66
+ We use the following datasets:
67
+
68
+ - [DUT-OMRON](http://saliencydetection.net/dut-omron/)
69
+ - [DUTS-TEST](http://saliencydetection.net/duts/)
70
+ - [ECSSD](https://www.cse.cuhk.edu.hk/leojia/projects/hsaliency/dataset.html)
71
+
72
+ Download the datasets and keep them in `datasets_local`.
73
+
74
+ #### Single Object Discovery
75
+
76
+ For single object discovery, we follow the framework used in [LOST](https://github.com/valeoai/LOST). Download the datasets and put them in the folder `datasets_local`.
77
+
78
+ - [VOC07](http://host.robots.ox.ac.uk/pascal/VOC/)
79
+ - [VOC12](http://host.robots.ox.ac.uk/pascal/VOC/)
80
+ - [COCO20k](https://cocodataset.org/#home)
81
+
82
+ Finally, download the masks of random streaks and holes of arbitrary shapes from [SCRIBBLES.zip](https://github.com/hasibzunair/masksup-segmentation/releases/download/v1.0/SCRIBBLES.zip) and put them inside the `datasets_local` folder (the data loaders look for them under `$DATASET_DIR/SCRIBBLES`).
83
+
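+ For reference, the data loaders (see `datasets/datasets.py`) expect a layout roughly like the one below inside `datasets_local`; exact sub-folder names inside the VOC and COCO downloads may differ slightly.
+
+ ```
+ datasets_local/
+ ├── DUTS-TR/        # DUTS-TR-Image, DUTS-TR-Mask
+ ├── DUTS-TE/        # DUTS-TE-Image, DUTS-TE-Mask
+ ├── ECSSD/          # images, ground_truth_mask
+ ├── DUT-OMRON/      # DUT-OMRON-image, pixelwiseGT-new-PNG
+ ├── VOC2007/
+ ├── VOC2012/
+ ├── COCO/           # images/, annotations/
+ └── SCRIBBLES/      # masks from SCRIBBLES.zip
+ ```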
84
+ ### DUTS-TR training
85
+
86
+ ```bash
87
+ export DATASET_DIR=datasets_local # root directory of training and evaluation datasets
88
+
89
+ python train.py --exp-name peekaboo --dataset-dir $DATASET_DIR
90
+ ```
91
+
92
+ View TensorBoard training logs by running: `tensorboard --logdir=outputs`.
93
+
94
+ ## 2b. Evaluation code
95
+
96
+ After training, the model checkpoint and logs are saved under `outputs/peekaboo-DUTS-TR-vit_small8`. Set the model path for evaluation:
97
+
98
+ ```bash
99
+ export MODEL="outputs/peekaboo-DUTS-TR-vit_small8/decoder_weights_niter500.pt"
100
+ ```
101
+
102
+ ### Unsupervised saliency detection eval
103
+
104
+ ```bash
105
+ # run evaluation
106
+ source evaluate_saliency.sh $MODEL $DATASET_DIR single
107
+ source evaluate_saliency.sh $MODEL $DATASET_DIR multi
108
+ ```
109
+
110
+ ### Single object discovery eval
111
+
112
+ ```bash
113
+ # run evaluation
114
+ source evaluate_uod.sh $MODEL $DATASET_DIR
115
+ ```
116
+
117
+ All experiments are conducted on a single NVIDIA 3080Ti GPU. For additional implementation details and results, please refer to the supplementary materials section in the paper.
118
+
119
+ ## 3. Pre-trained models
120
+
121
+ We provide pretrained models in [./data/weights/](./data/weights/) for reproducibility. Here are the main results of Peekaboo on the single object discovery task. For results on the unsupervised saliency detection task, we refer readers to our paper!
122
+
123
+ |Dataset | Backbone | CorLoc (%) | Download |
124
+ | ---------- | ------- | ------ | -------- |
125
+ | VOC07 | ViT-S/8 | 72.7 | [download](./data/weights/peekaboo_decoder_weights_niter500.pt) |
126
+ | VOC12 | ViT-S/8 | 75.9 | [download](./data/weights/peekaboo_decoder_weights_niter500.pt) |
127
+ | COCO20K | ViT-S/8 | 64.0 | [download](./data/weights/peekaboo_decoder_weights_niter500.pt) |
128
+
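+ As a quick check that the weights load, the released decoder can be instantiated with the same calls the demo app (`app.py`) uses. This is a minimal sketch and assumes it is run from the repository root with the environment from Section 1 activated:
+
+ ```python
+ # Load a pretrained Peekaboo decoder (mirrors the setup in app.py)
+ from model import PeekabooModel
+ from misc import load_config
+
+ config, _ = load_config("configs/peekaboo_DUTS-TR.yaml")
+ model = PeekabooModel(
+     vit_model=config.model["pre_training"],     # dino
+     vit_arch=config.model["arch"],              # vit_small
+     vit_patch_size=config.model["patch_size"],  # 8
+     enc_type_feats=config.peekaboo["feats"],    # "k"
+ )
+ model.decoder_load_weights("data/weights/peekaboo_decoder_weights_niter500.pt")
+ model.eval()
+ ```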
129
+ ## 4. Demo
130
+
131
+ We provide prediction demos of our models. The following command runs our method on a single image and visualizes the result.
132
+
133
+ ```bash
134
+ # infer on one image
135
+ python demo.py
136
+ ```
137
+
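+ A browser-based demo is also included: `app.py` wraps the same model in a Gradio interface, loading `configs/peekaboo_DUTS-TR.yaml` and `data/weights/peekaboo_decoder_weights_niter500.pt` by default. Assuming the environment above is set up, it can be launched with `python app.py`.
+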
138
+ ## 5. Citation
139
+
140
+ ```bibtex
141
+ @inproceedings{zunair2024peekaboo,
142
+ title={PEEKABOO: Hiding Parts of an Image for Unsupervised Object Localization},
143
+ author={Zunair, Hasib and Hamza, A Ben},
144
+ booktitle={Proc. British Machine Vision Conference},
145
+ year={2024}
146
+ }
147
+ ```
148
+
149
+ ## Project Notes
150
+
151
+ <details><summary>Click to view</summary>
152
+ <br>
153
+
154
+ **[Mar 18, 2024]** Infer on image folders.
155
+
156
+ ```bash
157
+ # infer on folder of images
158
+ python visualize_outputs.py --model-weights outputs/msl_a1.5_b1_g1_reg4-MSL-DUTS-TR-vit_small8/decoder_weights_niter500.pt --img-folder ./datasets_local/DUTS-TR/DUTS-TR-Image/ --output-dir outputs/visualizations/msl_masks
159
+ ```
160
+
161
+ **[Nov 10, 2023]** Reproduced FOUND results.
162
+
163
+ **[Nov 10, 2023]** Added project notes section.
164
+
165
+ </details>
166
+
167
+ ## Acknowledgements
168
+
169
+ This repository was built on top of [FOUND](https://github.com/valeoai/FOUND), [SelfMask](https://github.com/NoelShin/selfmask), [TokenCut](https://github.com/YangtaoWANG95/TokenCut) and [LOST](https://github.com/valeoai/LOST). Consider acknowledging these projects.
__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ import sys
2
+ from os.path import dirname, join
3
+
4
+ sys.path.insert(0, join(dirname(__file__), "."))
app.py ADDED
@@ -0,0 +1,119 @@
1
+ import os
2
+ import torch
3
+ import argparse
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import matplotlib.pyplot as plt
7
+ import gradio as gr
8
+ import codecs
9
+ import numpy as np
10
+ import cv2
11
+
12
+ from PIL import Image
13
+ from model import PeekabooModel
14
+ from misc import load_config
15
+ from torchvision import transforms as T
16
+
17
+ NORMALIZE = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
18
+
19
+ if __name__ == "__main__":
20
+
21
+ def inference(img_path):
22
+ # Load the image
23
+ with open(img_path, "rb") as f:
24
+ img = Image.open(f)
25
+ img = img.convert("RGB")
26
+ img_np = np.array(img)
27
+
28
+ # Preprocess
29
+ t = T.Compose([T.ToTensor(), NORMALIZE])
30
+ img_t = t(img)[None, :, :, :]
31
+ inputs = img_t.to(device)
32
+
33
+ # Forward step
34
+ print(f"Start Peekaboo prediction.")
35
+ with torch.no_grad():
36
+ preds = model(inputs, for_eval=True)
37
+ print(f"Done Peekaboo prediction.")
38
+
39
+ sigmoid = nn.Sigmoid()
40
+ h, w = img_t.shape[-2:]
41
+ preds_up = F.interpolate(
42
+ preds, scale_factor=model.vit_patch_size, mode="bilinear", align_corners=False
43
+ )[..., :h, :w]
44
+ preds_up = (sigmoid(preds_up.detach()) > 0.5).squeeze(0).float()
45
+ preds_up = preds_up.cpu().squeeze().numpy()
46
+
47
+ # Overlay predicted mask with input image
48
+ preds_up_np = (preds_up / np.max(preds_up) * 255).astype(np.uint8)
49
+ preds_up_np_3d = np.stack([preds_up_np, preds_up_np, preds_up_np], axis=-1)
50
+ combined_image = cv2.addWeighted(img_np, 0.5, preds_up_np_3d, 0.5, 0)
51
+ print(f"Output shape is {combined_image.shape}")
52
+ return combined_image
53
+
54
+ parser = argparse.ArgumentParser(
55
+ description="Evaluation of Peekaboo",
56
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
57
+ )
58
+
59
+ parser.add_argument(
60
+ "--img-path",
61
+ type=str,
62
+ default="data/examples/VOC_000030.jpg",
63
+ help="Image path.",
64
+ )
65
+ parser.add_argument(
66
+ "--model-weights",
67
+ type=str,
68
+ default="data/weights/peekaboo_decoder_weights_niter500.pt",
69
+ )
70
+ parser.add_argument(
71
+ "--config",
72
+ type=str,
73
+ default="configs/peekaboo_DUTS-TR.yaml",
74
+ )
75
+ parser.add_argument(
76
+ "--output-dir",
77
+ type=str,
78
+ default="outputs",
79
+ )
80
+ args = parser.parse_args()
81
+
82
+ # Configuration
83
+ config, _ = load_config(args.config)
84
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
85
+
86
+ # Load the model
87
+ model = PeekabooModel(
88
+ vit_model=config.model["pre_training"],
89
+ vit_arch=config.model["arch"],
90
+ vit_patch_size=config.model["patch_size"],
91
+ enc_type_feats=config.peekaboo["feats"],
92
+ )
93
+ # Load weights
94
+ model.decoder_load_weights(args.model_weights)
95
+ model.eval()
96
+ print(f"Model {args.model_weights} loaded correctly.")
97
+
98
+ # App
99
+ title = "PEEKABOO: Hiding Parts of an Image for Unsupervised Object Localization"
100
+ description = codecs.open("./media/description.html", "r", "utf-8").read()
101
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2407.17628' target='_blank'>PEEKABOO: Hiding Parts of an Image for Unsupervised Object Localization</a> | <a href='https://github.com/hasibzunair/peekaboo' target='_blank'>Github</a></p>"
102
+
103
+ gr.Interface(
104
+ inference,
105
+ gr.inputs.Image(type="filepath", label="Input Image"),
106
+ gr.outputs.Image(type="numpy", label="Predicted Output"),
107
+ examples=[
108
+ "./data/examples/a.jpeg",
109
+ "./data/examples/b.jpeg",
110
+ "./data/examples/c.jpeg",
111
+ "./data/examples/d.jpeg",
112
+ "./data/examples/e.jpeg"
113
+ ],
114
+ title=title,
115
+ description=description,
116
+ article=article,
117
+ allow_flagging=False,
118
+ analytics_enabled=False,
119
+ ).launch(debug=True, enable_queue=True)
bilateral_solver.py ADDED
@@ -0,0 +1,218 @@
1
+ """
2
+ Code adapted from TokenCut: https://github.com/YangtaoWANG95/TokenCut
3
+ """
4
+
5
+ import PIL.Image as Image
6
+ import numpy as np
7
+ from scipy import ndimage
8
+ from scipy.sparse import diags, csr_matrix
9
+ from scipy.sparse.linalg import cg
10
+
11
+ RGB_TO_YUV = np.array(
12
+ [[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]]
13
+ )
14
+ YUV_TO_RGB = np.array([[1.0, 0.0, 1.402], [1.0, -0.34414, -0.71414], [1.0, 1.772, 0.0]])
15
+ YUV_OFFSET = np.array([0, 128.0, 128.0]).reshape(1, 1, -1)
16
+ MAX_VAL = 255.0
17
+
18
+
19
+ def rgb2yuv(im):
20
+ return np.tensordot(im, RGB_TO_YUV, ([2], [1])) + YUV_OFFSET
21
+
22
+
23
+ def yuv2rgb(im):
24
+ return np.tensordot(im.astype(float) - YUV_OFFSET, YUV_TO_RGB, ([2], [1]))
25
+
26
+
27
+ def get_valid_idx(valid, candidates):
28
+ """Find which values are present in a list and where they are located"""
29
+ locs = np.searchsorted(valid, candidates)
30
+ # Handle edge case where the candidate is larger than all valid values
31
+ locs = np.clip(locs, 0, len(valid) - 1)
32
+ # Identify which values are actually present
33
+ valid_idx = np.flatnonzero(valid[locs] == candidates)
34
+ locs = locs[valid_idx]
35
+ return valid_idx, locs
36
+
37
+
38
+ class BilateralGrid(object):
39
+ def __init__(self, im, sigma_spatial=32, sigma_luma=8, sigma_chroma=8):
40
+ im_yuv = rgb2yuv(im)
41
+ # Compute 5-dimensional XYLUV bilateral-space coordinates
42
+ Iy, Ix = np.mgrid[: im.shape[0], : im.shape[1]]
43
+ x_coords = (Ix / sigma_spatial).astype(int)
44
+ y_coords = (Iy / sigma_spatial).astype(int)
45
+ luma_coords = (im_yuv[..., 0] / sigma_luma).astype(int)
46
+ chroma_coords = (im_yuv[..., 1:] / sigma_chroma).astype(int)
47
+ coords = np.dstack((x_coords, y_coords, luma_coords, chroma_coords))
48
+ coords_flat = coords.reshape(-1, coords.shape[-1])
49
+ self.npixels, self.dim = coords_flat.shape
50
+ # Hacky "hash vector" for coordinates,
51
+ # Requires all scaled coordinates be < MAX_VAL
52
+ self.hash_vec = MAX_VAL ** np.arange(self.dim)
53
+ # Construct S and B matrix
54
+ self._compute_factorization(coords_flat)
55
+
56
+ def _compute_factorization(self, coords_flat):
57
+ # Hash each coordinate in grid to a unique value
58
+ hashed_coords = self._hash_coords(coords_flat)
59
+ unique_hashes, unique_idx, idx = np.unique(
60
+ hashed_coords, return_index=True, return_inverse=True
61
+ )
62
+ # Identify unique set of vertices
63
+ unique_coords = coords_flat[unique_idx]
64
+ self.nvertices = len(unique_coords)
65
+ # Construct sparse splat matrix that maps from pixels to vertices
66
+ self.S = csr_matrix((np.ones(self.npixels), (idx, np.arange(self.npixels))))
67
+ # Construct sparse blur matrices.
68
+ # Note that these represent [1 0 1] blurs, excluding the central element
69
+ self.blurs = []
70
+ for d in range(self.dim):
71
+ blur = 0.0
72
+ for offset in (-1, 1):
73
+ offset_vec = np.zeros((1, self.dim))
74
+ offset_vec[:, d] = offset
75
+ neighbor_hash = self._hash_coords(unique_coords + offset_vec)
76
+ valid_coord, idx = get_valid_idx(unique_hashes, neighbor_hash)
77
+ blur = blur + csr_matrix(
78
+ (np.ones((len(valid_coord),)), (valid_coord, idx)),
79
+ shape=(self.nvertices, self.nvertices),
80
+ )
81
+ self.blurs.append(blur)
82
+
83
+ def _hash_coords(self, coord):
84
+ """Hacky function to turn a coordinate into a unique value"""
85
+ return np.dot(coord.reshape(-1, self.dim), self.hash_vec)
86
+
87
+ def splat(self, x):
88
+ return self.S.dot(x)
89
+
90
+ def slice(self, y):
91
+ return self.S.T.dot(y)
92
+
93
+ def blur(self, x):
94
+ """Blur a bilateral-space vector with a 1 2 1 kernel in each dimension"""
95
+ assert x.shape[0] == self.nvertices
96
+ out = 2 * self.dim * x
97
+ for blur in self.blurs:
98
+ out = out + blur.dot(x)
99
+ return out
100
+
101
+ def filter(self, x):
102
+ """Apply bilateral filter to an input x"""
103
+ return self.slice(self.blur(self.splat(x))) / self.slice(
104
+ self.blur(self.splat(np.ones_like(x)))
105
+ )
106
+
107
+
108
+ def bistochastize(grid, maxiter=10):
109
+ """Compute diagonal matrices to bistochastize a bilateral grid"""
110
+ m = grid.splat(np.ones(grid.npixels))
111
+ n = np.ones(grid.nvertices)
112
+ for i in range(maxiter):
113
+ n = np.sqrt(n * m / grid.blur(n))
114
+ # Correct m to satisfy the assumption of bistochastization regardless
115
+ # of how many iterations have been run.
116
+ m = n * grid.blur(n)
117
+ Dm = diags(m, 0)
118
+ Dn = diags(n, 0)
119
+ return Dn, Dm
120
+
121
+
122
+ class BilateralSolver(object):
123
+ def __init__(self, grid, params):
124
+ self.grid = grid
125
+ self.params = params
126
+ self.Dn, self.Dm = bistochastize(grid)
127
+
128
+ def solve(self, x, w):
129
+ # Check that w is a vector or a nx1 matrix
130
+ if w.ndim == 2:
131
+ assert w.shape[1] == 1
132
+ elif w.ndim == 1:
133
+ w = w.reshape(w.shape[0], 1)
134
+ A_smooth = self.Dm - self.Dn.dot(self.grid.blur(self.Dn))
135
+ w_splat = self.grid.splat(w)
136
+ A_data = diags(w_splat[:, 0], 0)
137
+ A = self.params["lam"] * A_smooth + A_data
138
+ xw = x * w
139
+ b = self.grid.splat(xw)
140
+ # Use simple Jacobi preconditioner
141
+ A_diag = np.maximum(A.diagonal(), self.params["A_diag_min"])
142
+ M = diags(1 / A_diag, 0)
143
+ # Flat initialization
144
+ y0 = self.grid.splat(xw) / w_splat
145
+ yhat = np.empty_like(y0)
146
+ for d in range(x.shape[-1]):
147
+ yhat[..., d], info = cg(
148
+ A,
149
+ b[..., d],
150
+ x0=y0[..., d],
151
+ M=M,
152
+ maxiter=self.params["cg_maxiter"],
153
+ tol=self.params["cg_tol"],
154
+ )
155
+ xhat = self.grid.slice(yhat)
156
+ return xhat
157
+
158
+
159
+ def bilateral_solver_output(
160
+ img_pth,
161
+ target,
162
+ img=None,
163
+ sigma_spatial=24,
164
+ sigma_luma=4,
165
+ sigma_chroma=4,
166
+ get_all_cc=False,
167
+ ):
168
+ if img is None:
169
+ reference = np.array(Image.open(img_pth).convert("RGB"))
170
+ else:
171
+ reference = np.array(img)
172
+
173
+ h, w = target.shape
174
+ confidence = np.ones((h, w)) * 0.999
175
+
176
+ grid_params = {
177
+ "sigma_luma": sigma_luma, # Brightness bandwidth
178
+ "sigma_chroma": sigma_chroma, # Color bandwidth
179
+ "sigma_spatial": sigma_spatial, # Spatial bandwidth
180
+ }
181
+
182
+ bs_params = {
183
+ "lam": 256, # The strength of the smoothness parameter
184
+ "A_diag_min": 1e-5, # Clamp the diagonal of the A diagonal in the Jacobi preconditioner.
185
+ "cg_tol": 1e-5, # The tolerance on the convergence in PCG
186
+ "cg_maxiter": 25, # The number of PCG iterations
187
+ }
188
+
189
+ grid = BilateralGrid(reference, **grid_params)
190
+
191
+ t = target.reshape(-1, 1).astype(np.double)
192
+ c = confidence.reshape(-1, 1).astype(np.double)
193
+
194
+ # output solver, which is a soft value
195
+ output_solver = BilateralSolver(grid, bs_params).solve(t, c).reshape((h, w))
196
+
197
+ binary_solver = ndimage.binary_fill_holes(output_solver > 0.5)
198
+ labeled, nr_objects = ndimage.label(binary_solver)
199
+
200
+ nb_pixel = [np.sum(labeled == i) for i in range(nr_objects + 1)]
201
+ pixel_order = np.argsort(nb_pixel)
202
+
203
+ if get_all_cc:
204
+ # Remove known background
205
+ pixel_descending_order = pixel_order[::-1]
206
+ # Get all connected components except the biggest one, which may be considered background; adjust here if needed
207
+ binary_solver = (
208
+ (labeled[None, :, :] == pixel_descending_order[1:, None, None])
209
+ .astype(int)
210
+ .sum(0)
211
+ )
212
+ else:
213
+ try:
214
+ binary_solver = labeled == pixel_order[-2]
215
+ except IndexError:
216
+ binary_solver = np.ones((h, w), dtype=bool)
217
+
218
+ return output_solver, binary_solver
bkg_seg.py ADDED
@@ -0,0 +1,86 @@
1
+ # Copyright 2022 - Valeo Comfort and Driving Assistance - Oriane Siméoni @ valeo.ai
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+
18
+ from typing import Tuple
19
+
20
+
21
+ def compute_img_bkg_seg(
22
+ attentions,
23
+ feats,
24
+ featmap_dims,
25
+ th_bkg,
26
+ dim=64,
27
+ epsilon: float = 1e-10,
28
+ apply_weights: bool = True,
29
+ ) -> Tuple[torch.Tensor, float]:
30
+ """
31
+ inputs
32
+ - attentions: self-attention maps of shape [B, num_heads, num_tokens, num_tokens]
33
+ """
34
+
35
+ w_featmap, h_featmap = featmap_dims
36
+
37
+ nb, nh, _ = attentions.shape[:3]
38
+ # we keep only the output patch attention
39
+ att = attentions[:, :, 0, 1:].reshape(nb, nh, -1)
40
+ att = att.reshape(nb, nh, w_featmap, h_featmap)
41
+
42
+ # -----------------------------------------------
43
+ # Inspired by the sparsity-based channel weighting of each attention head in CroW (Kalantidis et al.)
44
+ threshold = torch.mean(att.reshape(nb, -1), dim=1) # Find threshold per image
45
+ Q = torch.sum(
46
+ att.reshape(nb, nh, w_featmap * h_featmap) > threshold[:, None, None], axis=2
47
+ ) / (w_featmap * h_featmap)
48
+ beta = torch.log(torch.sum(Q + epsilon, dim=1)[:, None] / (Q + epsilon))
49
+
50
+ # Weight features based on attention sparsity
51
+ descs = feats[
52
+ :,
53
+ 1:,
54
+ ]
55
+ if apply_weights:
56
+ descs = (descs.reshape(nb, -1, nh, dim) * beta[:, None, :, None]).reshape(
57
+ nb, -1, nh * dim
58
+ )
59
+ else:
60
+ descs = (descs.reshape(nb, -1, nh, dim)).reshape(nb, -1, nh * dim)
61
+
62
+ # -----------------------------------------------
63
+ # Compute cosine-similarities
64
+ descs = F.normalize(descs, dim=-1, p=2)
65
+ cos_sim = torch.bmm(descs, descs.permute(0, 2, 1))
66
+
67
+ # -----------------------------------------------
68
+ # Find pixel with least amount of attention
69
+ if apply_weights:
70
+ att = att.reshape(nb, nh, w_featmap, h_featmap) * beta[:, :, None, None]
71
+ else:
72
+ att = att.reshape(nb, nh, w_featmap, h_featmap)
73
+ id_pixel_ref = torch.argmin(torch.sum(att, axis=1).reshape(nb, -1), dim=-1)
74
+
75
+ # -----------------------------------------------
76
+ # Mask of definitely background pixels: 1 on the background
77
+ cos_sim = cos_sim.reshape(nb, -1, w_featmap * h_featmap)
78
+
79
+ bkg_mask = (
80
+ cos_sim[torch.arange(cos_sim.size(0)), id_pixel_ref, :].reshape(
81
+ nb, w_featmap, h_featmap
82
+ )
83
+ > th_bkg
84
+ ) # mask to be used to remove background
85
+
86
+ return bkg_mask.float()
configs/peekaboo_DUTS-TR.yaml ADDED
@@ -0,0 +1,32 @@
1
+ model:
2
+ arch: vit_small
3
+ patch_size: 8
4
+ pre_training: dino
5
+
6
+ peekaboo:
7
+ feats: "k"
8
+
9
+ training:
10
+ dataset: DUTS-TR
11
+ dataset_set: null
12
+
13
+ # Hyper params
14
+ seed: 0
15
+ max_iter: 500
16
+ nb_epochs: 3
17
+ batch_size: 50
18
+ lr0: 5e-2
19
+ step_lr_size: 50
20
+ step_lr_gamma: 0.95
21
+
22
+ # Augmentations
23
+ crop_size: 224
24
+ scale_range: [0.1, 3.0]
25
+ photometric_aug: gaussian_blur
26
+ proba_photometric_aug: 0.5
27
+ cropping_strategy: random_scale
28
+
29
+ evaluation:
30
+ type: saliency
31
+ datasets: [DUT-OMRON, ECSSD]
32
+ freq: 50
data/.DS_Store ADDED
Binary file (6.15 kB). View file
 
data/coco_20k_filenames.txt ADDED
The diff for this file is too large to render. See raw diff
 
data/examples/.DS_Store ADDED
Binary file (6.15 kB). View file
 
data/examples/VOC_000030.jpg ADDED
data/examples/a.jpeg ADDED
data/examples/b.jpeg ADDED
data/examples/c.jpeg ADDED
data/examples/d.jpeg ADDED
data/examples/e.jpeg ADDED
data/weights/peekaboo_decoder_weights_niter250.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8621874b7459a940f2c584ef8d618c961eac407bc616ca7a76e3c90b745a61f7
3
+ size 2795
data/weights/peekaboo_decoder_weights_niter500.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:889f87ee21ea17a828d6065e3e187521989da1e94ebecc0f5988aaacb2a0c40f
3
+ size 2795
datasets/VOC.py ADDED
@@ -0,0 +1,82 @@
1
+ import os
2
+ from typing import Optional, Tuple, Union, Dict, List
3
+
4
+ import cv2
5
+ from pycocotools.coco import COCO
6
+ import numpy as np
7
+ import torch
8
+ import torchvision
9
+ from PIL import Image, PngImagePlugin
10
+ from torch.utils.data import Dataset
11
+ from torchvision import transforms as T
12
+ from torchvision.transforms import ColorJitter, RandomApply, RandomGrayscale
13
+ from tqdm import tqdm
14
+
15
+ VOCDetectionMetadataType = Dict[str, Dict[str, Union[str, Dict[str, str], List[str]]]]
16
+
17
+
18
+ def get_voc_detection_gt(
19
+ metadata: VOCDetectionMetadataType, remove_hards: bool = False
20
+ ) -> Tuple[np.array, List[str]]:
21
+ objects = metadata["annotation"]["object"]
22
+ nb_obj = len(objects)
23
+
24
+ gt_bbxs = []
25
+ gt_clss = []
26
+ for object in range(nb_obj):
27
+ if remove_hards and (
28
+ objects[object]["truncated"] == "1" or objects[object]["difficult"] == "1"
29
+ ):
30
+ continue
31
+
32
+ gt_cls = objects[object]["name"]
33
+ gt_clss.append(gt_cls)
34
+ obj = objects[object]["bndbox"]
35
+ x1y1x2y2 = [
36
+ int(obj["xmin"]),
37
+ int(obj["ymin"]),
38
+ int(obj["xmax"]),
39
+ int(obj["ymax"]),
40
+ ]
41
+
42
+ # Original annotations are integers in the range [1, W or H]
43
+ # Assuming they mean 1-based pixel indices (inclusive),
44
+ # a box with annotation (xmin=1, xmax=W) covers the whole image.
45
+ # In coordinate space this is represented by (xmin=0, xmax=W)
46
+ x1y1x2y2[0] -= 1
47
+ x1y1x2y2[1] -= 1
48
+ gt_bbxs.append(x1y1x2y2)
49
+
50
+ return np.asarray(gt_bbxs), gt_clss
51
+
52
+
53
+ def create_gt_masks_if_voc(labels: PngImagePlugin.PngImageFile) -> Image.Image:
54
+ mask = np.array(labels)
55
+ mask_gt = (mask > 0).astype(float)
56
+ mask_gt = np.where(mask_gt != 0.0, 255, mask_gt)
57
+ mask_gt = Image.fromarray(np.uint8(mask_gt))
58
+ return mask_gt
59
+
60
+
61
+ def create_VOC_loader(img_dir, dataset_set, evaluation_type):
62
+ year = img_dir[-4:]
63
+ download = not os.path.exists(img_dir)
64
+ if evaluation_type == "uod":
65
+ loader = torchvision.datasets.VOCDetection(
66
+ img_dir,
67
+ year=year,
68
+ image_set=dataset_set,
69
+ transform=None,
70
+ download=download,
71
+ )
72
+ elif evaluation_type == "saliency":
73
+ loader = torchvision.datasets.VOCSegmentation(
74
+ img_dir,
75
+ year=year,
76
+ image_set=dataset_set,
77
+ transform=None,
78
+ download=download,
79
+ )
80
+ else:
81
+ raise ValueError(f"Not implemented for {evaluation_type}.")
82
+ return loader
datasets/__init__.py ADDED
File without changes
datasets/augmentations.py ADDED
@@ -0,0 +1,70 @@
1
+ """
2
+ Code borrowed from SelfMask: https://github.com/NoelShin/selfmask
3
+ """
4
+
5
+ import numpy as np
6
+ import torch
7
+ from PIL import Image
8
+ from typing import Optional, Tuple, Union
9
+ from torchvision.transforms import ColorJitter, RandomApply, RandomGrayscale
10
+
11
+ from datasets.utils import GaussianBlur
12
+ from datasets.geometric_transforms import (
13
+ random_scale,
14
+ random_crop,
15
+ random_hflip,
16
+ )
17
+
18
+
19
+ def geometric_augmentations(
20
+ image: Image.Image,
21
+ random_scale_range: Optional[Tuple[float, float]] = None,
22
+ random_crop_size: Optional[int] = None,
23
+ random_hflip_p: Optional[float] = None,
24
+ mask: Optional[Union[Image.Image, np.ndarray, torch.Tensor]] = None,
25
+ ignore_index: Optional[int] = None,
26
+ ) -> Tuple[Image.Image, torch.Tensor]:
27
+ """Note. image and mask are assumed to be of base size, thus share a spatial shape."""
28
+ if random_scale_range is not None:
29
+ image, mask = random_scale(
30
+ image=image, random_scale_range=random_scale_range, mask=mask
31
+ )
32
+
33
+ if random_crop_size is not None:
34
+ crop_size = (random_crop_size, random_crop_size)
35
+ fill = tuple(np.array(image).mean(axis=(0, 1)).astype(np.uint8).tolist())
36
+ image, offset = random_crop(image=image, crop_size=crop_size, fill=fill)
37
+
38
+ if mask is not None:
39
+ assert ignore_index is not None
40
+ mask = random_crop(
41
+ image=mask, crop_size=crop_size, fill=ignore_index, offset=offset
42
+ )[0]
43
+
44
+ if random_hflip_p is not None:
45
+ image, mask = random_hflip(image=image, p=random_hflip_p, mask=mask)
46
+ return image, mask
47
+
48
+
49
+ def photometric_augmentations(
50
+ image: Image.Image,
51
+ random_color_jitter: bool,
52
+ random_grayscale: bool,
53
+ random_gaussian_blur: bool,
54
+ proba_photometric_aug: float,
55
+ ) -> torch.Tensor:
56
+ if random_color_jitter:
57
+ color_jitter = ColorJitter(
58
+ brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2
59
+ )
60
+ image = RandomApply([color_jitter], p=proba_photometric_aug)(image)
61
+
62
+ if random_grayscale:
63
+ image = RandomGrayscale(proba_photometric_aug)(image)
64
+
65
+ if random_gaussian_blur:
66
+ w, h = image.size
67
+ image = GaussianBlur(kernel_size=int((0.1 * min(w, h) // 2 * 2) + 1))(
68
+ image, proba_photometric_aug
69
+ )
70
+ return image
datasets/datasets.py ADDED
@@ -0,0 +1,476 @@
1
+ # Code for Peekaboo
2
+ # Author: Hasib Zunair
3
+ # Modified from https://github.com/NoelShin/selfmask
4
+
5
+ """
6
+ Dataset functions for applying Normalized Cut.
7
+ """
8
+
9
+ import os
10
+ import glob
11
+ import random
12
+ from typing import Optional, Tuple, Union
13
+
14
+ from pycocotools.coco import COCO
15
+ import numpy as np
16
+ import torch
17
+ import torchvision
18
+ from PIL import Image
19
+ from torch.utils.data import Dataset
20
+ from torchvision import transforms as T
21
+
22
+ try:
23
+ from torchvision.transforms import InterpolationMode
24
+
25
+ BICUBIC = InterpolationMode.BICUBIC
26
+ except ImportError:
27
+ BICUBIC = Image.BICUBIC
28
+
29
+ from datasets.utils import unnormalize
30
+ from datasets.geometric_transforms import resize
31
+ from datasets.VOC import get_voc_detection_gt, create_gt_masks_if_voc, create_VOC_loader
32
+ from datasets.augmentations import geometric_augmentations, photometric_augmentations
33
+
34
+ from datasets.uod_datasets import UODDataset
35
+
36
+ NORMALIZE = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
37
+
38
+
39
+ def set_dataset_dir(dataset_name, root_dir):
40
+ if dataset_name == "ECSSD":
41
+ dataset_dir = os.path.join(root_dir, "ECSSD")
42
+ img_dir = os.path.join(dataset_dir, "images")
43
+ gt_dir = os.path.join(dataset_dir, "ground_truth_mask")
44
+ scribbles_dir = os.path.join(root_dir, "SCRIBBLES")
45
+
46
+ elif dataset_name == "DUTS-TEST":
47
+ dataset_dir = os.path.join(root_dir, "DUTS-TE")
48
+ img_dir = os.path.join(dataset_dir, "DUTS-TE-Image")
49
+ gt_dir = os.path.join(dataset_dir, "DUTS-TE-Mask")
50
+ scribbles_dir = os.path.join(root_dir, "SCRIBBLES")
51
+
52
+ elif dataset_name == "DUTS-TR":
53
+ dataset_dir = os.path.join(root_dir, "DUTS-TR")
54
+ img_dir = os.path.join(dataset_dir, "DUTS-TR-Image")
55
+ gt_dir = os.path.join(dataset_dir, "DUTS-TR-Mask")
56
+ scribbles_dir = os.path.join(root_dir, "SCRIBBLES")
57
+
58
+ elif dataset_name == "DUT-OMRON":
59
+ dataset_dir = os.path.join(root_dir, "DUT-OMRON")
60
+ img_dir = os.path.join(dataset_dir, "DUT-OMRON-image")
61
+ gt_dir = os.path.join(dataset_dir, "pixelwiseGT-new-PNG")
62
+ scribbles_dir = os.path.join(root_dir, "SCRIBBLES")
63
+
64
+ elif dataset_name == "VOC07":
65
+ dataset_dir = os.path.join(root_dir, "VOC2007")
66
+ img_dir = dataset_dir
67
+ gt_dir = dataset_dir
68
+ scribbles_dir = os.path.join(root_dir, "SCRIBBLES")
69
+
70
+ elif dataset_name == "VOC12":
71
+ dataset_dir = os.path.join(root_dir, "VOC2012")
72
+ img_dir = dataset_dir
73
+ gt_dir = dataset_dir
74
+ scribbles_dir = os.path.join(root_dir, "SCRIBBLES")
75
+
76
+ elif dataset_name == "COCO17":
77
+ dataset_dir = os.path.join(root_dir, "COCO")
78
+ img_dir = dataset_dir
79
+ gt_dir = dataset_dir
80
+ scribbles_dir = os.path.join(root_dir, "SCRIBBLES")
81
+
82
+ elif dataset_name == "ImageNet":
83
+ dataset_dir = os.path.join(root_dir, "ImageNet")
84
+ img_dir = dataset_dir
85
+ gt_dir = dataset_dir
86
+
87
+ else:
88
+ raise ValueError(f"Unknown dataset {dataset_name}")
89
+
90
+ return img_dir, gt_dir, scribbles_dir
91
+
92
+
93
+ def build_dataset(
94
+ root_dir: str,
95
+ dataset_name: str,
96
+ dataset_set: Optional[str] = None,
97
+ for_eval: bool = False,
98
+ config=None,
99
+ evaluation_type="saliency", # uod,
100
+ ):
101
+ """
102
+ Build dataset
103
+ """
104
+
105
+ if evaluation_type == "saliency":
106
+ # training data loaded from here
107
+ img_dir, gt_dir, scribbles_dir = set_dataset_dir(dataset_name, root_dir)
108
+ dataset = PeekabooDataset(
109
+ name=dataset_name,
110
+ img_dir=img_dir,
111
+ gt_dir=gt_dir,
112
+ scribbles_dir=scribbles_dir,
113
+ dataset_set=dataset_set,
114
+ config=config,
115
+ for_eval=for_eval,
116
+ evaluation_type=evaluation_type,
117
+ )
118
+
119
+ elif evaluation_type == "uod":
120
+ assert dataset_name in ["VOC07", "VOC12", "COCO20k"]
121
+ dataset_set = "trainval" if dataset_name in ["VOC07", "VOC12"] else "train"
122
+ no_hards = False
123
+ dataset = UODDataset(
124
+ dataset_name,
125
+ dataset_set,
126
+ root_dir=root_dir,
127
+ remove_hards=no_hards,
128
+ )
129
+
130
+ return dataset
131
+
132
+
133
+ class PeekabooDataset(Dataset):
134
+ def __init__(
135
+ self,
136
+ name: str,
137
+ img_dir: str,
138
+ gt_dir: str,
139
+ scribbles_dir: str,
140
+ dataset_set: Optional[str] = None,
141
+ config=None,
142
+ for_eval: bool = False,
143
+ evaluation_type: str = "saliency",
144
+ ) -> None:
145
+ """
146
+ Args:
147
+ root_dir (string): Directory with all the images.
148
+ transform (callable, optional): Optional transform to be applied
149
+ on a sample.
150
+ """
151
+ self.for_eval = for_eval
152
+ self.use_aug = not for_eval
153
+ self.evaluation_type = evaluation_type
154
+
155
+ assert evaluation_type in ["saliency"]
156
+
157
+ self.name = name
158
+ self.dataset_set = dataset_set
159
+ self.img_dir = img_dir
160
+ self.gt_dir = gt_dir
161
+ self.scribbles_dir = scribbles_dir
162
+
163
+ # if VOC dataset
164
+ self.loader = None
165
+ self.cocoGt = None
166
+
167
+ self.config = config
168
+
169
+ if "VOC" in self.name:
170
+ self.loader = create_VOC_loader(self.img_dir, dataset_set, evaluation_type)
171
+
172
+ # if ImageNet dataset
173
+ elif "ImageNet" in self.name:
174
+ self.loader = torchvision.datasets.ImageNet(
175
+ self.img_dir,
176
+ split=dataset_set,
177
+ transform=None,
178
+ target_transform=None,
179
+ )
180
+
181
+ elif "COCO" in self.name:
182
+ year = int("20" + self.name[-2:])
183
+ annFile = f"/datasets_local/COCO/annotations/instances_{dataset_set}{str(year)}.json"
184
+ self.cocoGt = COCO(annFile)
185
+ self.img_ids = list(sorted(self.cocoGt.getImgIds()))
186
+ self.img_dir = f"/datasets_local/COCO/images/{dataset_set}{str(year)}/"
187
+
188
+ # Transformations
189
+ if self.for_eval:
190
+ (
191
+ full_img_transform,
192
+ no_norm_full_img_transform,
193
+ ) = self.get_init_transformation(isVOC="VOC" in name)
194
+ self.full_img_transform = full_img_transform
195
+ self.no_norm_full_img_transform = no_norm_full_img_transform
196
+
197
+ # Images
198
+ self.list_images = None
199
+ self.list_scribbles = None
200
+ if not "VOC" in self.name and not "COCO" in self.name:
201
+ self.list_images = [
202
+ os.path.join(img_dir, i) for i in sorted(os.listdir(img_dir))
203
+ ]
204
+ # get paths to scribbles; heavily-masked scribbles are used, see https://github.com/hasibzunair/msl-recognition
205
+ self.list_scribbles = sorted(glob.glob(scribbles_dir + "/*.png"))[::-1][
206
+ :1000
207
+ ] # For heavy masking [::-1]
208
+
209
+ self.ignore_index = -1
210
+ self.mean = NORMALIZE.mean
211
+ self.std = NORMALIZE.std
212
+ self.to_tensor_and_normalize = T.Compose([T.ToTensor(), NORMALIZE])
213
+ self.normalize = NORMALIZE
214
+
215
+ if config is not None and self.use_aug:
216
+ self._set_aug(config)
217
+
218
+ def get_init_transformation(self, isVOC: bool = False):
219
+ if isVOC:
220
+ t = T.Compose(
221
+ [T.PILToTensor(), T.ConvertImageDtype(torch.float), NORMALIZE]
222
+ )
223
+ t_nonorm = T.Compose([T.PILToTensor(), T.ConvertImageDtype(torch.float)])
224
+ return t, t_nonorm
225
+
226
+ else:
227
+ t = T.Compose([T.ToTensor(), NORMALIZE])
228
+ t_nonorm = T.Compose([T.ToTensor()])
229
+ return t, t_nonorm
230
+
231
+ def _set_aug(self, config):
232
+ """
233
+ Set augmentation based on config.
234
+ """
235
+
236
+ photometric_aug = config.training["photometric_aug"]
237
+
238
+ self.cropping_strategy = config.training["cropping_strategy"]
239
+ if self.cropping_strategy == "center_crop":
240
+ self.use_aug = False # default strategy, not considered to be a data aug
241
+ self.scale_range = config.training["scale_range"]
242
+ self.crop_size = config.training["crop_size"]
243
+ self.center_crop_transforms = T.Compose(
244
+ [
245
+ T.CenterCrop((self.crop_size, self.crop_size)),
246
+ T.ToTensor(),
247
+ ]
248
+ )
249
+ self.center_crop_only_transforms = T.Compose(
250
+ [T.CenterCrop((self.crop_size, self.crop_size)), T.PILToTensor()]
251
+ )
252
+
253
+ self.proba_photometric_aug = config.training["proba_photometric_aug"]
254
+
255
+ self.random_color_jitter = False
256
+ self.random_grayscale = False
257
+ self.random_gaussian_blur = False
258
+ if photometric_aug == "color_jitter":
259
+ self.random_color_jitter = True
260
+ elif photometric_aug == "grayscale":
261
+ self.random_grayscale = True
262
+ elif photometric_aug == "gaussian_blur":
263
+ self.random_gaussian_blur = True
264
+
265
+ def _preprocess_data_aug(
266
+ self,
267
+ image: Image.Image,
268
+ mask: Image.Image,
269
+ ignore_index: Optional[int] = None,
270
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
271
+ """Prepare data in a proper form for either training (data augmentation) or validation."""
272
+
273
+ # resize to base size
274
+ image = resize(
275
+ image,
276
+ size=self.crop_size,
277
+ edge="shorter",
278
+ interpolation="bilinear",
279
+ )
280
+ mask = resize(
281
+ mask,
282
+ size=self.crop_size,
283
+ edge="shorter",
284
+ interpolation="bilinear",
285
+ )
286
+
287
+ if not isinstance(mask, torch.Tensor):
288
+ mask: torch.Tensor = torch.tensor(np.array(mask))
289
+
290
+ random_scale_range = None
291
+ random_crop_size = None
292
+ random_hflip_p = None
293
+ if self.cropping_strategy == "random_scale":
294
+ random_scale_range = self.scale_range
295
+ elif self.cropping_strategy == "random_crop":
296
+ random_crop_size = self.crop_size
297
+ elif self.cropping_strategy == "random_hflip":
298
+ random_hflip_p = 0.5
299
+ elif self.cropping_strategy == "random_crop_and_hflip":
300
+ random_hflip_p = 0.5
301
+ random_crop_size = self.crop_size
302
+
303
+ if random_crop_size or random_hflip_p or random_scale_range:
304
+ image, mask = geometric_augmentations(
305
+ image=image,
306
+ mask=mask,
307
+ random_scale_range=random_scale_range,
308
+ random_crop_size=random_crop_size,
309
+ ignore_index=ignore_index,
310
+ random_hflip_p=random_hflip_p,
311
+ )
312
+
313
+ if random_scale_range:
314
+ # resize to (self.crop_size, self.crop_size)
315
+ image = resize(
316
+ image,
317
+ size=self.crop_size,
318
+ interpolation="bilinear",
319
+ )
320
+ mask = resize(
321
+ mask,
322
+ size=(self.crop_size, self.crop_size),
323
+ interpolation="bilinear",
324
+ )
325
+
326
+ image = photometric_augmentations(
327
+ image,
328
+ random_color_jitter=self.random_color_jitter,
329
+ random_grayscale=self.random_grayscale,
330
+ random_gaussian_blur=self.random_gaussian_blur,
331
+ proba_photometric_aug=self.proba_photometric_aug,
332
+ )
333
+
334
+ # to tensor + normalize image
335
+ image = self.to_tensor_and_normalize(image)
336
+
337
+ return image, mask
338
+
339
+ def __len__(self) -> int:
340
+ if "VOC" in self.name:
341
+ return len(self.loader)
342
+ elif "ImageNet" in self.name:
343
+ return len(self.loader)
344
+ elif "COCO" in self.name:
345
+ return len(self.img_ids)
346
+ return len(self.list_images)
347
+
348
+ def _apply_center_crop(
349
+ self, image: Image.Image, mask: Union[Image.Image, np.ndarray, torch.Tensor]
350
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
351
+ img_t = self.center_crop_transforms(image)
352
+ # need to normalize image
353
+ img_t = self.normalize(img_t)
354
+ mask_gt = self.center_crop_transforms(mask).squeeze()
355
+ return img_t, mask_gt
356
+
357
+ def _preprocess_scribble(self, img, img_size):
358
+ transform = T.Compose(
359
+ [
360
+ T.Resize(img_size, BICUBIC),
361
+ T.CenterCrop(img_size),
362
+ T.ToTensor(),
363
+ ]
364
+ )
365
+ return transform(img)
366
+
367
+ def __getitem__(self, idx, get_mask_gt=True):
368
+ if "VOC" in self.name:
369
+ img, gt_labels = self.loader[idx]
370
+ if self.evaluation_type == "uod":
371
+ gt_labels, _ = get_voc_detection_gt(gt_labels, remove_hards=False)
372
+ elif self.evaluation_type == "saliency":
373
+ mask_gt = create_gt_masks_if_voc(gt_labels)
374
+ img_path = self.loader.images[idx]
375
+
376
+ elif "ImageNet" in self.name:
377
+ img, _ = self.loader[idx]
378
+ img_path = self.loader.imgs[idx][0]
379
+ # empty mask since no gt mask, only class label
380
+ zeros = np.zeros(np.array(img).shape[:2])
381
+ mask_gt = Image.fromarray(zeros)
382
+
383
+ elif "COCO" in self.name:
384
+ img_id = self.img_ids[idx]
385
+
386
+ path = self.cocoGt.loadImgs(img_id)[0]["file_name"]
387
+ img = Image.open(os.path.join(self.img_dir, path)).convert("RGB")
388
+ _ = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(img_id))
389
+ img_path = self.img_ids[idx] # What matters most is the id for eval
390
+
391
+ # empty mask since no gt mask, only class label
392
+ zeros = np.zeros(np.array(img).shape[:2])
393
+ mask_gt = Image.fromarray(zeros)
394
+
395
+ # For all others
396
+ else:
397
+ img_path = self.list_images[idx]
398
+ scribble_path = self.list_scribbles[random.randint(0, 950)]
399
+
400
+ # read image
401
+ with open(img_path, "rb") as f:
402
+ img = Image.open(f)
403
+ img = img.convert("RGB")
404
+ im_name = img_path.split("/")[-1]
405
+ mask_gt = Image.open(
406
+ os.path.join(self.gt_dir, im_name.replace(".jpg", ".png"))
407
+ ).convert("L")
408
+
409
+ if self.for_eval:
410
+ img_t = self.full_img_transform(img)
411
+ img_init = self.no_norm_full_img_transform(img)
412
+
413
+ if self.evaluation_type == "saliency":
414
+ mask_gt = torch.tensor(np.array(mask_gt)).squeeze()
415
+ mask_gt = np.array(mask_gt)
416
+ mask_gt = mask_gt == 255
417
+ mask_gt = torch.tensor(mask_gt)
418
+ else:
419
+ if self.use_aug:
420
+ img_t, mask_gt = self._preprocess_data_aug(
421
+ image=img, mask=mask_gt, ignore_index=self.ignore_index
422
+ )
423
+ mask_gt = np.array(mask_gt)
424
+ mask_gt = mask_gt == 255
425
+ mask_gt = torch.tensor(mask_gt)
426
+ else:
427
+ # no data aug
428
+ img_t, mask_gt = self._apply_center_crop(image=img, mask=mask_gt)
429
+ gt_labels = self.center_crop_only_transforms(gt_labels).squeeze()
430
+ mask_gt = np.asarray(mask_gt, np.int64)
431
+ mask_gt = mask_gt == 1
432
+ mask_gt = torch.tensor(mask_gt)
433
+
434
+ img_init = unnormalize(img_t)
435
+
436
+ if not get_mask_gt:
437
+ mask_gt = None
438
+
439
+ if self.evaluation_type == "uod":
440
+ gt_labels = torch.tensor(gt_labels)
441
+ mask_gt = gt_labels
442
+
443
+ # read scribble
444
+ with open(scribble_path, "rb") as f:
445
+ scribble = Image.open(f).convert("P")
446
+ scribble = self._preprocess_scribble(scribble, img_t.shape[1])
447
+ scribble = (scribble > 0).float() # threshold to [0,1]
448
+ scribble = torch.max(scribble) - scribble # inverted scribble
449
+
450
+ # create masked input image with scribble when training
451
+ if not self.for_eval:
452
+ masked_img_t = img_t * scribble
453
+ masked_img_init = unnormalize(masked_img_t)
454
+ else:
455
+ masked_img_t = img_t
456
+ masked_img_init = img_init
457
+
458
+ # returns the
459
+ # image, masked image, scribble,
460
+ # un-normalized image, un-normalized masked image
461
+ # ground truth mask, image path
462
+ return (
463
+ img_t,
464
+ masked_img_t,
465
+ scribble,
466
+ img_init,
467
+ masked_img_init,
468
+ mask_gt,
469
+ img_path,
470
+ )
471
+
472
+ def fullimg_mode(self):
473
+ self.val_full_image = True
474
+
475
+ def training_mode(self):
476
+ self.val_full_image = False
datasets/geometric_transforms.py ADDED
@@ -0,0 +1,160 @@
1
+ """
2
+ Code adapted from SelfMask: https://github.com/NoelShin/selfmask
3
+ """
4
+
5
+ from random import randint, random, uniform
6
+ from typing import Optional, Tuple, Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torchvision.transforms.functional as TF
11
+ from PIL import Image
12
+ from torchvision.transforms.functional import InterpolationMode as IM
13
+
14
+
15
+ def random_crop(
16
+ image: Union[Image.Image, np.ndarray, torch.Tensor],
17
+ crop_size: Tuple[int, int], # (h, w)
18
+ fill: Union[int, Tuple[int, int, int]], # an unsigned integer or RGB,
19
+ offset: Optional[Tuple[int, int]] = None, # (top, left) coordinate of a crop
20
+ ):
21
+ assert type(crop_size) in (tuple, list) and len(crop_size) == 2
22
+
23
+ if isinstance(image, np.ndarray):
24
+ image = torch.tensor(image)
25
+ h, w = image.shape[-2:]
26
+ elif isinstance(image, Image.Image):
27
+ w, h = image.size
28
+ elif isinstance(image, torch.Tensor):
29
+ h, w = image.shape[-2:]
30
+ else:
31
+ raise TypeError(type(image))
32
+
33
+ pad_h, pad_w = max(crop_size[0] - h, 0), max(crop_size[1] - w, 0)
34
+
35
+ image = TF.pad(image, [0, 0, pad_w, pad_h], fill=fill, padding_mode="constant")
36
+
37
+ if isinstance(image, Image.Image):
38
+ w, h = image.size
39
+ else:
40
+ h, w = image.shape[-2:]
41
+
42
+ if offset is None:
43
+ offset = (randint(0, h - crop_size[0]), randint(0, w - crop_size[1]))
44
+
45
+ image = TF.crop(
46
+ image, top=offset[0], left=offset[1], height=crop_size[0], width=crop_size[1]
47
+ )
48
+ return image, offset
49
+
50
+
51
+ def compute_size(
52
+ input_size: Tuple[int, int], output_size: int, edge: str # h, w
53
+ ) -> Tuple[int, int]:
54
+ assert edge in ["shorter", "longer"]
55
+ h, w = input_size
56
+
57
+ if edge == "longer":
58
+ if w > h:
59
+ h = int(float(h) / w * output_size)
60
+ w = output_size
61
+ else:
62
+ w = int(float(w) / h * output_size)
63
+ h = output_size
64
+ assert w <= output_size and h <= output_size
65
+
66
+ else:
67
+ if w > h:
68
+ w = int(float(w) / h * output_size)
69
+ h = output_size
70
+ else:
71
+ h = int(float(h) / w * output_size)
72
+ w = output_size
73
+ assert w >= output_size and h >= output_size
74
+ return h, w
75
+
76
+
77
+ def resize(
78
+ image: Union[Image.Image, np.ndarray, torch.Tensor],
79
+ size: Union[int, Tuple[int, int]],
80
+ interpolation: str,
81
+ edge: str = "both",
82
+ ) -> Union[Image.Image, torch.Tensor]:
83
+ """
84
+ :param image: an image to be resized
85
+ :param size: a resulting image size
86
+ :param interpolation: sampling mode. ["nearest", "bilinear", "bicubic"]
87
+ :param edge: Default: "both"
88
+ No-op if a size is given as a tuple (h, w).
89
+ If set to "both", resize both height and width to the specified size.
90
+ If set to "shorter", resize the shorter edge to the specified size keeping the aspect ratio.
91
+ If set to "longer", resize the longer edge to the specified size keeping the aspect ratio.
92
+ :return: a resized image
93
+ """
94
+ assert interpolation in ["nearest", "bilinear", "bicubic"], ValueError(
95
+ interpolation
96
+ )
97
+ assert edge in ["both", "shorter", "longer"], ValueError(edge)
98
+ interpolation = {
99
+ "nearest": IM.NEAREST,
100
+ "bilinear": IM.BILINEAR,
101
+ "bicubic": IM.BICUBIC,
102
+ }[interpolation]
103
+
104
+ if type(image) == torch.Tensor:
105
+ image = image.clone().detach()
106
+ elif type(image) == np.ndarray:
107
+ image = torch.from_numpy(image)
108
+
109
+ if type(size) is tuple:
110
+ if type(image) == torch.Tensor and len(image.shape) == 2:
111
+ image = TF.resize(
112
+ image.unsqueeze(dim=0), size=size, interpolation=interpolation
113
+ ).squeeze(dim=0)
114
+ else:
115
+ image = TF.resize(image, size=size, interpolation=interpolation)
116
+
117
+ else:
118
+ if edge == "both":
119
+ image = TF.resize(image, size=[size, size], interpolation=interpolation)
120
+
121
+ else:
122
+ if isinstance(image, Image.Image):
123
+ w, h = image.size
124
+ else:
125
+ h, w = image.shape[-2:]
126
+ rh, rw = compute_size(input_size=(h, w), output_size=size, edge=edge)
127
+ image = TF.resize(image, size=[rh, rw], interpolation=interpolation)
128
+ return image
129
+
130
+
131
+ def random_scale(
132
+ image: Union[Image.Image, np.ndarray, torch.Tensor],
133
+ random_scale_range: Tuple[float, float],
134
+ mask: Optional[Union[Image.Image, np.ndarray, torch.Tensor]] = None,
135
+ ):
136
+ scale = uniform(*random_scale_range)
137
+ if isinstance(image, Image.Image):
138
+ w, h = image.size
139
+ else:
140
+ h, w = image.shape[-2:]
141
+ w_rs, h_rs = int(w * scale), int(h * scale)
142
+ image: Image.Image = resize(image, size=(h_rs, w_rs), interpolation="bilinear")
143
+ if mask is not None:
144
+ mask = resize(mask, size=(h_rs, w_rs), interpolation="nearest")
145
+ return image, mask
146
+
147
+
148
+ def random_hflip(
149
+ image: Union[Image.Image, np.ndarray, torch.Tensor],
150
+ p: float,
151
+ mask: Optional[Union[np.ndarray, torch.Tensor]] = None,
152
+ ):
153
+ assert 0.0 <= p <= 1.0, ValueError(random_hflip)
154
+
155
+ # Return a random floating point number in the range [0.0, 1.0).
156
+ if random() > p:
157
+ image = TF.hflip(image)
158
+ if mask is not None:
159
+ mask = TF.hflip(mask)
160
+ return image, mask
datasets/uod_datasets.py ADDED
@@ -0,0 +1,396 @@
1
+ # Copyright 2021 - Valeo Comfort and Driving Assistance - Oriane Siméoni @ valeo.ai
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Code adapted from previous method LOST: https://github.com/valeoai/LOST
17
+ """
18
+
19
+ import os
20
+ import math
21
+ import torch
22
+ import json
23
+ import torchvision
24
+ import numpy as np
25
+ import skimage.io
26
+
27
+ from PIL import Image
28
+ from tqdm import tqdm
29
+ from torchvision import transforms as pth_transforms
30
+
31
+ # Image transformation applied to all images
32
+ transform = pth_transforms.Compose(
33
+ [
34
+ pth_transforms.ToTensor(),
35
+ pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
36
+ ]
37
+ )
38
+
39
+
40
+ class ImageDataset:
41
+ def __init__(self, image_path):
42
+
43
+ self.image_path = image_path
44
+ self.name = image_path.split("/")[-1]
45
+
46
+ # Read the image
47
+ with open(image_path, "rb") as f:
48
+ img = Image.open(f)
49
+ img = img.convert("RGB")
50
+
51
+ # Build a dataloader
52
+ img = transform(img)
53
+ self.dataloader = [[img, image_path]]
54
+
55
+ def get_image_name(self, *args, **kwargs):
56
+ return self.image_path.split("/")[-1].split(".")[0]
57
+
58
+ def load_image(self, *args, **kwargs):
59
+ return skimage.io.imread(self.image_path)
60
+
61
+
62
+ class UODDataset:
63
+ def __init__(
64
+ self,
65
+ dataset_name,
66
+ dataset_set,
67
+ root_dir,
68
+ remove_hards: bool = False,
69
+ ):
70
+ """
71
+ Build the dataloader
72
+ """
73
+
74
+ self.dataset_name = dataset_name
75
+ self.set = dataset_set
76
+ self.root_dir = root_dir
77
+
78
+ if dataset_name == "VOC07":
79
+ self.root_path = f"{root_dir}/VOC2007"
80
+ self.year = "2007"
81
+ elif dataset_name == "VOC12":
82
+ self.root_path = f"{root_dir}/VOC2012"
83
+ self.year = "2012"
84
+ elif dataset_name == "COCO20k":
85
+ self.year = "2014"
86
+ self.root_path = f"{root_dir}/COCO/images/{dataset_set}{self.year}"
87
+ self.sel20k = "data/coco_20k_filenames.txt"
88
+ # new JSON file constructed based on COCO train2014 gt
89
+ self.all_annfile = f"{root_dir}/COCO/annotations/instances_train2014.json"
90
+ self.annfile = (
91
+ f"{root_dir}/COCO/annotations/instances_train2014_sel20k.json"
92
+ )
93
+ if not os.path.exists(self.annfile):
94
+ select_coco_20k(self.sel20k, self.all_annfile)
95
+ else:
96
+ raise ValueError("Unknown dataset.")
97
+
98
+ if not os.path.exists(self.root_path):
99
+ raise ValueError("Please follow the README to setup the datasets.")
100
+
101
+ self.name = f"{self.dataset_name}_{self.set}"
102
+
103
+ # Build the dataloader
104
+ # import pdb; pdb.set_trace()
105
+
106
+ if "VOC" in dataset_name:
107
+ self.dataloader = torchvision.datasets.VOCDetection(
108
+ self.root_path,
109
+ year=self.year,
110
+ image_set=self.set,
111
+ transform=transform,
112
+ download=False,
113
+ )
114
+ elif "COCO20k" == dataset_name:
115
+ self.dataloader = torchvision.datasets.CocoDetection(
116
+ self.root_path, annFile=self.annfile, transform=transform
117
+ )
118
+ else:
119
+ raise ValueError("Unknown dataset.")
120
+
121
+ # Set hards images that are not included
122
+ self.remove_hards = remove_hards
123
+ self.hards = []
124
+ if remove_hards:
125
+ self.name += f"-nohards"
126
+ self.hards = self.get_hards()
127
+ print(f"Nb images discarded {len(self.hards)}")
128
+
129
+ def __len__(self) -> int:
130
+ return len(self.dataloader)
131
+
132
+ def load_image(self, im_name):
133
+ """
134
+ Load the image corresponding to the im_name
135
+ """
136
+ if "VOC" in self.dataset_name:
137
+ image = skimage.io.imread(
138
+ f"{self.root_dir}/VOC{self.year}/JPEGImages/{im_name}"
139
+ )
140
+ elif "COCO" in self.dataset_name:
141
+ im_path = self.path_20k[self.sel_20k.index(im_name)]
142
+ image = skimage.io.imread(f"{self.root_dir}/COCO/images/{im_path}")
143
+ else:
144
+ raise ValueError("Unknown dataset.")
145
+ return image
146
+
147
+ def get_image_name(self, inp):
148
+ """
149
+ Return the image name
150
+ """
151
+ if "VOC" in self.dataset_name:
152
+ im_name = inp["annotation"]["filename"]
153
+ elif "COCO" in self.dataset_name:
154
+ im_name = str(inp[0]["image_id"])
155
+
156
+ return im_name
157
+
158
+ def extract_gt(self, targets, im_name):
159
+ if "VOC" in self.dataset_name:
160
+ return extract_gt_VOC(targets, remove_hards=self.remove_hards)
161
+ elif "COCO" in self.dataset_name:
162
+ return extract_gt_COCO(targets, remove_iscrowd=True)
163
+ else:
164
+ raise ValueError("Unknown dataset")
165
+
166
+ def extract_classes(self):
167
+ if "VOC" in self.dataset_name:
168
+ cls_path = f"classes_{self.set}_{self.year}.txt"
169
+ elif "COCO" in self.dataset_name:
170
+ cls_path = f"classes_{self.dataset_name}_{self.set}_{self.year}.txt"
171
+
172
+ # Load if exists
173
+ if os.path.exists(cls_path):
174
+ all_classes = []
175
+ with open(cls_path, "r") as f:
176
+ for line in f:
177
+ all_classes.append(line.strip())
178
+ else:
179
+ print("Extract all classes from the dataset")
180
+ if "VOC" in self.dataset_name:
181
+ all_classes = self.extract_classes_VOC()
182
+ elif "COCO" in self.dataset_name:
183
+ all_classes = self.extract_classes_COCO()
184
+
185
+ with open(cls_path, "w") as f:
186
+ for s in all_classes:
187
+ f.write(str(s) + "\n")
188
+
189
+ return all_classes
190
+
191
+ def extract_classes_VOC(self):
192
+ all_classes = []
193
+ for im_id, inp in enumerate(tqdm(self.dataloader)):
194
+ objects = inp[1]["annotation"]["object"]
195
+
196
+ for o in range(len(objects)):
197
+ if objects[o]["name"] not in all_classes:
198
+ all_classes.append(objects[o]["name"])
199
+
200
+ return all_classes
201
+
202
+ def extract_classes_COCO(self):
203
+ all_classes = []
204
+ for im_id, inp in enumerate(tqdm(self.dataloader)):
205
+ objects = inp[1]
206
+
207
+ for o in range(len(objects)):
208
+ if objects[o]["category_id"] not in all_classes:
209
+ all_classes.append(objects[o]["category_id"])
210
+
211
+ return all_classes
212
+
213
+ def get_hards(self):
214
+ hard_path = "datasets/hard_%s_%s_%s.txt" % (
215
+ self.dataset_name,
216
+ self.set,
217
+ self.year,
218
+ )
219
+ if os.path.exists(hard_path):
220
+ hards = []
221
+ with open(hard_path, "r") as f:
222
+ for line in f:
223
+ hards.append(int(line.strip()))
224
+ else:
225
+ print("Discover hard images that should be discarded")
226
+
227
+ if "VOC" in self.dataset_name:
228
+ # set the hards
229
+ hards = discard_hard_voc(self.dataloader)
230
+
231
+ with open(hard_path, "w") as f:
232
+ for s in hards:
233
+ f.write(str(s) + "\n")
234
+
235
+ return hards
236
+
237
+
238
+ def discard_hard_voc(dataloader):
239
+ hards = []
240
+ for im_id, inp in enumerate(tqdm(dataloader)):
241
+ objects = inp[1]["annotation"]["object"]
242
+ nb_obj = len(objects)
243
+
244
+ hard = np.zeros(nb_obj)
245
+ for i, o in enumerate(range(nb_obj)):
246
+ hard[i] = (
247
+ 1
248
+ if (objects[o]["truncated"] == "1" or objects[o]["difficult"] == "1")
249
+ else 0
250
+ )
251
+
252
+ # all images with only truncated or difficult objects
253
+ if np.sum(hard) == nb_obj:
254
+ hards.append(im_id)
255
+ return hards
256
+
257
+
258
+ def extract_gt_COCO(targets, remove_iscrowd=True):
259
+ objects = targets
260
+ nb_obj = len(objects)
261
+
262
+ gt_bbxs = []
263
+ gt_clss = []
264
+ for o in range(nb_obj):
265
+ # Remove iscrowd boxes
266
+ if remove_iscrowd and objects[o]["iscrowd"] == 1:
267
+ continue
268
+ gt_cls = objects[o]["category_id"]
269
+ gt_clss.append(gt_cls)
270
+ bbx = objects[o]["bbox"]
271
+ x1y1x2y2 = [bbx[0], bbx[1], bbx[0] + bbx[2], bbx[1] + bbx[3]]
272
+ x1y1x2y2 = [int(round(x)) for x in x1y1x2y2]
273
+ gt_bbxs.append(x1y1x2y2)
274
+
275
+ return np.asarray(gt_bbxs), gt_clss
276
+
277
+
278
+ def extract_gt_VOC(targets, remove_hards=False):
279
+ objects = targets["annotation"]["object"]
280
+ nb_obj = len(objects)
281
+
282
+ gt_bbxs = []
283
+ gt_clss = []
284
+ for o in range(nb_obj):
285
+ if remove_hards and (
286
+ objects[o]["truncated"] == "1" or objects[o]["difficult"] == "1"
287
+ ):
288
+ continue
289
+ gt_cls = objects[o]["name"]
290
+ gt_clss.append(gt_cls)
291
+ obj = objects[o]["bndbox"]
292
+ x1y1x2y2 = [
293
+ int(obj["xmin"]),
294
+ int(obj["ymin"]),
295
+ int(obj["xmax"]),
296
+ int(obj["ymax"]),
297
+ ]
298
+ # Original annotations are integers in the range [1, W or H]
299
+ # Assuming they mean 1-based pixel indices (inclusive),
300
+ # a box with annotation (xmin=1, xmax=W) covers the whole image.
301
+ # In coordinate space this is represented by (xmin=0, xmax=W)
302
+ x1y1x2y2[0] -= 1
303
+ x1y1x2y2[1] -= 1
304
+ gt_bbxs.append(x1y1x2y2)
305
+
306
+ return np.asarray(gt_bbxs), gt_clss
307
+
308
+
309
+ def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
310
+ # https://github.com/ultralytics/yolov5/blob/develop/utils/general.py
311
+ # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
312
+ box2 = box2.T
313
+
314
+ # Get the coordinates of bounding boxes
315
+ if x1y1x2y2: # x1, y1, x2, y2 = box1
316
+ b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
317
+ b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
318
+ else: # transform from xywh to xyxy
319
+ b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
320
+ b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
321
+ b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
322
+ b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
323
+
324
+ # Intersection area
325
+ inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * (
326
+ torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)
327
+ ).clamp(0)
328
+
329
+ # Union Area
330
+ w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
331
+ w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
332
+ union = w1 * h1 + w2 * h2 - inter + eps
333
+
334
+ iou = inter / union
335
+ if GIoU or DIoU or CIoU:
336
+ cw = torch.max(b1_x2, b2_x2) - torch.min(
337
+ b1_x1, b2_x1
338
+ ) # convex (smallest enclosing box) width
339
+ ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
340
+ if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
341
+ c2 = cw**2 + ch**2 + eps # convex diagonal squared
342
+ rho2 = (
343
+ (b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2
344
+ + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2
345
+ ) / 4 # center distance squared
346
+ if DIoU:
347
+ return iou - rho2 / c2 # DIoU
348
+ elif (
349
+ CIoU
350
+ ): # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
351
+ v = (4 / math.pi**2) * torch.pow(
352
+ torch.atan(w2 / h2) - torch.atan(w1 / h1), 2
353
+ )
354
+ with torch.no_grad():
355
+ alpha = v / (v - iou + (1 + eps))
356
+ return iou - (rho2 / c2 + v * alpha) # CIoU
357
+ else: # GIoU https://arxiv.org/pdf/1902.09630.pdf
358
+ c_area = cw * ch + eps # convex area
359
+ return iou - (c_area - union) / c_area # GIoU
360
+ else:
361
+ return iou # IoU
362
+
363
+
364
+ def select_coco_20k(sel_file, all_annotations_file):
365
+ print("Building COCO 20k dataset.")
366
+
367
+ # load all annotations
368
+ with open(all_annotations_file, "r") as f:
369
+ train2014 = json.load(f)
370
+
371
+ # load selected images
372
+ with open(sel_file, "r") as f:
373
+ sel_20k = f.readlines()
374
+ sel_20k = [s.replace("\n", "") for s in sel_20k]
375
+ im20k = [str(int(s.split("_")[-1].split(".")[0])) for s in sel_20k]
376
+
377
+ new_anno = []
378
+ new_images = []
379
+
380
+ for i in tqdm(im20k):
381
+ new_anno.extend(
382
+ [a for a in train2014["annotations"] if a["image_id"] == int(i)]
383
+ )
384
+ new_images.extend([a for a in train2014["images"] if a["id"] == int(i)])
385
+
386
+ train2014_20k = {}
387
+ train2014_20k["images"] = new_images
388
+ train2014_20k["annotations"] = new_anno
389
+ train2014_20k["categories"] = train2014["categories"]
390
+
391
+ with open(
392
+ "datasets_local/COCO/annotations/instances_train2014_sel20k.json", "w"
393
+ ) as outfile:
394
+ json.dump(train2014_20k, outfile)
395
+
396
+ print("Done.")
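The `bbox_iou` helper above is what the single object discovery evaluation relies on to decide whether a predicted box matches a ground-truth box. Below is a minimal usage sketch, not part of this commit; the import path is a placeholder because the file name of this module sits outside the portion of the diff shown here, and the box values are made up.

```python
# Minimal usage sketch (illustration only, not part of this commit).
import torch
# Placeholder import path; point it at wherever the bbox_iou above is defined.
from datasets.uod_datasets import bbox_iou

pred_box = torch.tensor([10.0, 20.0, 110.0, 220.0])         # one box as (x1, y1, x2, y2)
gt_boxes = torch.tensor([[12.0, 18.0, 100.0, 210.0],
                         [300.0, 50.0, 400.0, 150.0]])       # n x 4 ground-truth boxes
ious = bbox_iou(pred_box, gt_boxes)                          # IoU of pred_box against each gt box
print(ious)                      # e.g. roughly tensor([0.83, 0.00]) for these values
print(bool(ious.max() >= 0.5))   # the usual "correct localization" criterion
```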
datasets/utils.py ADDED
@@ -0,0 +1,45 @@
1
+ import numpy as np
2
+ import torch
3
+ from PIL import Image
4
+ from torchvision import transforms as T
5
+
6
+ NORMALIZE = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
7
+
8
+
9
+ class GaussianBlur:
10
+ """
11
+ Code borrowed from SelfMask: https://github.com/NoelShin/selfmask
12
+ """
13
+
14
+ # Implements Gaussian blur as described in the SimCLR paper
15
+ def __init__(self, kernel_size: float, min: float = 0.1, max: float = 2.0) -> None:
16
+ self.min = min
17
+ self.max = max
18
+ # kernel size is set to be 10% of the image height/width
19
+ self.kernel_size = kernel_size
20
+
21
+ def __call__(self, sample: Image.Image, random_gaussian_blur_p: float):
22
+ sample = np.array(sample)
23
+
24
+ # blur the image with a 50% chance
25
+ prob = np.random.random_sample()
26
+
27
+ if prob < 0.5:
28
+ import cv2
29
+
30
+ sigma = (self.max - self.min) * np.random.random_sample() + self.min
31
+ sample = cv2.GaussianBlur(
32
+ sample, (self.kernel_size, self.kernel_size), sigma
33
+ )
34
+ return sample
35
+
36
+
37
+ def unnormalize(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
38
+ """
39
+ Code borrowed from STEGO: https://github.com/mhamilton723/STEGO
40
+ """
41
+ image2 = torch.clone(image)
42
+ for t, m, s in zip(image2, mean, std):
43
+ t.mul_(s).add_(m)
44
+
45
+ return image2
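A small sketch of how `unnormalize` above pairs with `NORMALIZE`, for example to turn a normalized input tensor back into something displayable; random data, illustration only, not part of this commit.

```python
# Minimal usage sketch (illustration only, not part of this commit).
import torch
from datasets.utils import NORMALIZE, unnormalize

x = torch.rand(3, 224, 224)    # fake image with values in [0, 1]
x_norm = NORMALIZE(x)          # ImageNet-style normalization used throughout the repo
x_back = unnormalize(x_norm)   # undo it channel by channel
print(torch.allclose(x, x_back, atol=1e-5))  # True up to floating point error
```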
demo.py ADDED
@@ -0,0 +1,118 @@
1
+ # Code for Peekaboo
2
+ # Author: Hasib Zunair
3
+ # Modified from https://github.com/valeoai/FOUND, see license below.
4
+
5
+ # Copyright 2022 - Valeo Comfort and Driving Assistance - Oriane Siméoni @ valeo.ai
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+ """Visualize model predictions"""
20
+
21
+ import os
22
+ import torch
23
+ import argparse
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+ import matplotlib.pyplot as plt
27
+
28
+ from PIL import Image
29
+ from model import PeekabooModel
30
+ from misc import load_config
31
+ from torchvision import transforms as T
32
+
33
+ NORMALIZE = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
34
+
35
+ if __name__ == "__main__":
36
+ parser = argparse.ArgumentParser(
37
+ description="Evaluation of Peekaboo",
38
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
39
+ )
40
+
41
+ parser.add_argument(
42
+ "--img-path",
43
+ type=str,
44
+ default="data/examples/VOC_000030.jpg",
45
+ help="Image path.",
46
+ )
47
+ parser.add_argument(
48
+ "--model-weights",
49
+ type=str,
50
+ default="data/weights/peekaboo_decoder_weights_niter500.pt",
51
+ )
52
+ parser.add_argument(
53
+ "--config",
54
+ type=str,
55
+ default="configs/peekaboo_DUTS-TR.yaml",
56
+ )
57
+ parser.add_argument(
58
+ "--output-dir",
59
+ type=str,
60
+ default="outputs",
61
+ )
62
+ args = parser.parse_args()
63
+
64
+ # Saving dir
65
+ if not os.path.exists(args.output_dir):
66
+ os.makedirs(args.output_dir)
67
+
68
+ # Configuration
69
+ config, _ = load_config(args.config)
70
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
71
+
72
+ # Load the model
73
+ model = PeekabooModel(
74
+ vit_model=config.model["pre_training"],
75
+ vit_arch=config.model["arch"],
76
+ vit_patch_size=config.model["patch_size"],
77
+ enc_type_feats=config.peekaboo["feats"],
78
+ )
79
+ # Load weights
80
+ model.decoder_load_weights(args.model_weights)
81
+ model.eval()
82
+ print(f"Model {args.model_weights} loaded correctly.")
83
+
84
+ # Load the image
85
+ with open(args.img_path, "rb") as f:
86
+ img = Image.open(f)
87
+ img = img.convert("RGB")
88
+
89
+ t = T.Compose([T.ToTensor(), NORMALIZE])
90
+ img_t = t(img)[None, :, :, :]
91
+ inputs = img_t.to(device)
92
+
93
+ # Forward step
94
+ with torch.no_grad():
95
+ preds = model(inputs, for_eval=True)
96
+
97
+ sigmoid = nn.Sigmoid()
98
+ h, w = img_t.shape[-2:]
99
+ preds_up = F.interpolate(
100
+ preds, scale_factor=model.vit_patch_size, mode="bilinear", align_corners=False
101
+ )[..., :h, :w]
102
+ preds_up = (sigmoid(preds_up.detach()) > 0.5).squeeze(0).float()
103
+
104
+ plt.figure()
105
+ plt.imshow(img)
106
+ plt.imshow(
107
+ preds_up.cpu().squeeze().numpy(), "gray", interpolation="none", alpha=0.5
108
+ )
109
+ plt.axis("off")
110
+ img_name = args.img_path
111
+ img_name = img_name.split("/")[-1].split(".")[0]
112
+ plt.savefig(
113
+ os.path.join(args.output_dir, f"{img_name}-peekaboo.png"),
114
+ bbox_inches="tight",
115
+ pad_inches=0,
116
+ )
117
+ plt.close()
118
+ print(f"Saved model prediction.")
dino ADDED
@@ -0,0 +1 @@
1
+ Subproject commit 7c446df5b9f45747937fb0d72314eb9f7b66930a
environment.yml ADDED
@@ -0,0 +1,575 @@
1
+ # Environment used for this work
2
+ name: peekaboo
3
+ channels:
4
+ - defaults
5
+ - conda-forge
6
+ dependencies:
7
+ - _libgcc_mutex=0.1=main
8
+ - _openmp_mutex=5.1=1_gnu
9
+ - abseil-cpp=20211102.0=hd4dd3e8_0
10
+ - aiobotocore=2.5.0=py38h06a4308_0
11
+ - aiofiles=22.1.0=py38h06a4308_0
12
+ - aiohttp=3.8.5=py38h5eee18b_0
13
+ - aioitertools=0.7.1=pyhd3eb1b0_0
14
+ - aiosignal=1.2.0=pyhd3eb1b0_0
15
+ - aiosqlite=0.18.0=py38h06a4308_0
16
+ - alabaster=0.7.12=pyhd3eb1b0_0
17
+ - anaconda=2023.09=py38_mkl_1
18
+ - aom=3.6.0=h6a678d5_0
19
+ - appdirs=1.4.4=pyhd3eb1b0_0
20
+ - argon2-cffi=21.3.0=pyhd3eb1b0_0
21
+ - argon2-cffi-bindings=21.2.0=py38h7f8727e_0
22
+ - arrow=1.2.3=py38h06a4308_1
23
+ - arrow-cpp=11.0.0=h374c478_2
24
+ - astroid=2.14.2=py38h06a4308_0
25
+ - astropy=5.1=py38h7deecbd_0
26
+ - asttokens=2.0.5=pyhd3eb1b0_0
27
+ - async-timeout=4.0.2=py38h06a4308_0
28
+ - atomicwrites=1.4.0=py_0
29
+ - attrs=22.1.0=py38h06a4308_0
30
+ - automat=20.2.0=py_0
31
+ - autopep8=1.6.0=pyhd3eb1b0_1
32
+ - aws-c-common=0.6.8=h5eee18b_1
33
+ - aws-c-event-stream=0.1.6=h6a678d5_6
34
+ - aws-checksums=0.1.11=h5eee18b_2
35
+ - aws-sdk-cpp=1.8.185=h721c034_1
36
+ - babel=2.11.0=py38h06a4308_0
37
+ - backcall=0.2.0=pyhd3eb1b0_0
38
+ - bcrypt=3.2.0=py38h5eee18b_1
39
+ - beautifulsoup4=4.12.2=py38h06a4308_0
40
+ - binaryornot=0.4.4=pyhd3eb1b0_1
41
+ - blas=1.0=mkl
42
+ - bleach=4.1.0=pyhd3eb1b0_0
43
+ - blosc=1.21.3=h6a678d5_0
44
+ - bokeh=2.4.3=py38h06a4308_0
45
+ - boost-cpp=1.73.0=h7f8727e_12
46
+ - botocore=1.29.76=py38h06a4308_0
47
+ - bottleneck=1.3.5=py38h7deecbd_0
48
+ - brotli=1.0.9=h5eee18b_7
49
+ - brotli-bin=1.0.9=h5eee18b_7
50
+ - brotlipy=0.7.0=py38h27cfd23_1003
51
+ - brunsli=0.1=h2531618_0
52
+ - bzip2=1.0.8=h7b6447c_0
53
+ - c-ares=1.19.1=h5eee18b_0
54
+ - c-blosc2=2.8.0=h6a678d5_0
55
+ - ca-certificates=2023.08.22=h06a4308_0
56
+ - certifi=2023.7.22=py38h06a4308_0
57
+ - cffi=1.15.1=py38h5eee18b_3
58
+ - cfitsio=3.470=h5893167_7
59
+ - chardet=4.0.0=py38h06a4308_1003
60
+ - charls=2.2.0=h2531618_0
61
+ - charset-normalizer=2.0.4=pyhd3eb1b0_0
62
+ - click=8.0.4=py38h06a4308_0
63
+ - cloudpickle=2.2.1=py38h06a4308_0
64
+ - colorama=0.4.6=py38h06a4308_0
65
+ - colorcet=3.0.1=py38h06a4308_0
66
+ - comm=0.1.2=py38h06a4308_0
67
+ - constantly=15.1.0=pyh2b92418_0
68
+ - contourpy=1.0.5=py38hdb19cb5_0
69
+ - cookiecutter=1.7.3=pyhd3eb1b0_0
70
+ - cryptography=41.0.3=py38hdda0065_0
71
+ - cssselect=1.1.0=pyhd3eb1b0_0
72
+ - curl=8.2.1=hdbd6064_0
73
+ - cyrus-sasl=2.1.28=h52b45da_1
74
+ - cytoolz=0.12.0=py38h5eee18b_0
75
+ - daal4py=2023.1.1=py38h79cecc1_0
76
+ - dal=2023.1.1=hdb19cb5_48679
77
+ - dask=2023.4.1=py38h06a4308_1
78
+ - dask-core=2023.4.1=py38h06a4308_0
79
+ - datasets=2.12.0=py38h06a4308_0
80
+ - datashader=0.15.2=py38h06a4308_0
81
+ - datashape=0.5.4=py38h06a4308_1
82
+ - dav1d=1.2.1=h5eee18b_0
83
+ - dbus=1.13.18=hb2f20db_0
84
+ - debugpy=1.6.7=py38h6a678d5_0
85
+ - decorator=5.1.1=pyhd3eb1b0_0
86
+ - defusedxml=0.7.1=pyhd3eb1b0_0
87
+ - diff-match-patch=20200713=pyhd3eb1b0_0
88
+ - dill=0.3.6=py38h06a4308_0
89
+ - distributed=2023.4.1=py38h06a4308_1
90
+ - docstring-to-markdown=0.11=py38h06a4308_0
91
+ - docutils=0.18.1=py38h06a4308_3
92
+ - entrypoints=0.4=py38h06a4308_0
93
+ - et_xmlfile=1.1.0=py38h06a4308_0
94
+ - exceptiongroup=1.0.4=py38h06a4308_0
95
+ - executing=0.8.3=pyhd3eb1b0_0
96
+ - expat=2.5.0=h6a678d5_0
97
+ - filelock=3.9.0=py38h06a4308_0
98
+ - flake8=6.0.0=py38h06a4308_0
99
+ - flask=2.2.2=py38h06a4308_0
100
+ - font-ttf-dejavu-sans-mono=2.37=hd3eb1b0_0
101
+ - font-ttf-inconsolata=2.001=hcb22688_0
102
+ - font-ttf-source-code-pro=2.030=hd3eb1b0_0
103
+ - font-ttf-ubuntu=0.83=h8b1ccd4_0
104
+ - fontconfig=2.14.1=h4c34cd2_2
105
+ - fonts-anaconda=1=h8fa9717_0
106
+ - fonttools=4.25.0=pyhd3eb1b0_0
107
+ - freetype=2.12.1=h4a9f257_0
108
+ - frozenlist=1.3.3=py38h5eee18b_0
109
+ - fsspec=2023.4.0=py38h06a4308_0
110
+ - gensim=4.3.0=py38h6a678d5_0
111
+ - gflags=2.2.2=he6710b0_0
112
+ - giflib=5.2.1=h5eee18b_3
113
+ - glib=2.69.1=he621ea3_2
114
+ - glog=0.5.0=h2531618_0
115
+ - gmp=6.2.1=h295c915_3
116
+ - gmpy2=2.1.2=py38heeb90bb_0
117
+ - greenlet=2.0.1=py38h6a678d5_0
118
+ - grpc-cpp=1.48.2=he1ff14a_1
119
+ - gst-plugins-base=1.14.1=h6a678d5_1
120
+ - gstreamer=1.14.1=h5eee18b_1
121
+ - h5py=3.9.0=py38he06866b_0
122
+ - hdf5=1.12.1=h2b7332f_3
123
+ - heapdict=1.0.1=pyhd3eb1b0_0
124
+ - holoviews=1.17.1=py38h06a4308_0
125
+ - huggingface_hub=0.15.1=py38h06a4308_0
126
+ - hvplot=0.8.4=py38h06a4308_0
127
+ - hyperlink=21.0.0=pyhd3eb1b0_0
128
+ - icu=58.2=he6710b0_3
129
+ - imagecodecs=2023.1.23=py38hc4b7b5f_0
130
+ - imageio=2.31.1=py38h06a4308_0
131
+ - imagesize=1.4.1=py38h06a4308_0
132
+ - imbalanced-learn=0.10.1=py38h06a4308_1
133
+ - importlib-metadata=6.0.0=py38h06a4308_0
134
+ - importlib_metadata=6.0.0=hd3eb1b0_0
135
+ - importlib_resources=5.2.0=pyhd3eb1b0_1
136
+ - incremental=21.3.0=pyhd3eb1b0_0
137
+ - inflection=0.5.1=py38h06a4308_0
138
+ - iniconfig=1.1.1=pyhd3eb1b0_0
139
+ - intake=0.6.8=py38h06a4308_0
140
+ - intel-openmp=2023.1.0=hdb19cb5_46305
141
+ - intervaltree=3.1.0=pyhd3eb1b0_0
142
+ - ipykernel=6.25.0=py38h2f386ee_0
143
+ - ipython=8.12.2=py38h06a4308_0
144
+ - ipython_genutils=0.2.0=pyhd3eb1b0_1
145
+ - ipywidgets=8.0.4=py38h06a4308_0
146
+ - isort=5.9.3=pyhd3eb1b0_0
147
+ - itemadapter=0.3.0=pyhd3eb1b0_0
148
+ - itemloaders=1.0.4=pyhd3eb1b0_1
149
+ - itsdangerous=2.0.1=pyhd3eb1b0_0
150
+ - jaraco.classes=3.2.1=pyhd3eb1b0_0
151
+ - jedi=0.18.1=py38h06a4308_1
152
+ - jeepney=0.7.1=pyhd3eb1b0_0
153
+ - jellyfish=1.0.1=py38hb02cf49_0
154
+ - jinja2=3.1.2=py38h06a4308_0
155
+ - jinja2-time=0.2.0=pyhd3eb1b0_3
156
+ - jmespath=0.10.0=pyhd3eb1b0_0
157
+ - joblib=1.2.0=py38h06a4308_0
158
+ - jpeg=9e=h5eee18b_1
159
+ - jq=1.6=h27cfd23_1000
160
+ - json5=0.9.6=pyhd3eb1b0_0
161
+ - jsonschema=4.17.3=py38h06a4308_0
162
+ - jupyter=1.0.0=py38h06a4308_8
163
+ - jupyter_client=7.4.9=py38h06a4308_0
164
+ - jupyter_console=6.6.3=py38h06a4308_0
165
+ - jupyter_core=5.3.0=py38h06a4308_0
166
+ - jupyter_events=0.6.3=py38h06a4308_0
167
+ - jupyter_server=1.23.4=py38h06a4308_0
168
+ - jupyter_server_fileid=0.9.0=py38h06a4308_0
169
+ - jupyter_server_ydoc=0.8.0=py38h06a4308_1
170
+ - jupyter_ydoc=0.2.4=py38h06a4308_0
171
+ - jupyterlab=3.6.3=py38h06a4308_0
172
+ - jupyterlab_pygments=0.1.2=py_0
173
+ - jupyterlab_server=2.22.0=py38h06a4308_0
174
+ - jupyterlab_widgets=3.0.5=py38h06a4308_0
175
+ - jxrlib=1.1=h7b6447c_2
176
+ - kaleido-core=0.2.1=h7c8854e_0
177
+ - keyring=23.13.1=py38h06a4308_0
178
+ - kiwisolver=1.4.4=py38h6a678d5_0
179
+ - krb5=1.20.1=h143b758_1
180
+ - lazy-object-proxy=1.6.0=py38h27cfd23_0
181
+ - lcms2=2.12=h3be6417_0
182
+ - ld_impl_linux-64=2.38=h1181459_1
183
+ - lerc=3.0=h295c915_0
184
+ - libaec=1.0.4=he6710b0_1
185
+ - libavif=0.11.1=h5eee18b_0
186
+ - libboost=1.73.0=h28710b8_12
187
+ - libbrotlicommon=1.0.9=h5eee18b_7
188
+ - libbrotlidec=1.0.9=h5eee18b_7
189
+ - libbrotlienc=1.0.9=h5eee18b_7
190
+ - libclang=14.0.6=default_hc6dbbc7_1
191
+ - libclang13=14.0.6=default_he11475f_1
192
+ - libcups=2.4.2=h2d74bed_1
193
+ - libcurl=8.2.1=h251f7ec_0
194
+ - libdeflate=1.17=h5eee18b_0
195
+ - libedit=3.1.20221030=h5eee18b_0
196
+ - libev=4.33=h7f8727e_1
197
+ - libevent=2.1.12=hdbd6064_1
198
+ - libffi=3.4.4=h6a678d5_0
199
+ - libgcc-ng=11.2.0=h1234567_1
200
+ - libgfortran-ng=11.2.0=h00389a5_1
201
+ - libgfortran5=11.2.0=h1234567_1
202
+ - libgomp=11.2.0=h1234567_1
203
+ - libllvm14=14.0.6=hdb19cb5_3
204
+ - libnghttp2=1.52.0=h2d74bed_1
205
+ - libpng=1.6.39=h5eee18b_0
206
+ - libpq=12.15=hdbd6064_1
207
+ - libprotobuf=3.20.3=he621ea3_0
208
+ - libsodium=1.0.18=h7b6447c_0
209
+ - libspatialindex=1.9.3=h2531618_0
210
+ - libssh2=1.10.0=hdbd6064_2
211
+ - libstdcxx-ng=11.2.0=h1234567_1
212
+ - libthrift=0.15.0=h1795dd8_2
213
+ - libtiff=4.5.1=h6a678d5_0
214
+ - libuuid=1.41.5=h5eee18b_0
215
+ - libwebp=1.3.2=h11a3e52_0
216
+ - libwebp-base=1.3.2=h5eee18b_0
217
+ - libxcb=1.15=h7f8727e_0
218
+ - libxkbcommon=1.0.1=h5eee18b_1
219
+ - libxml2=2.10.4=hcbfbd50_0
220
+ - libxslt=1.1.37=h2085143_0
221
+ - libzopfli=1.0.3=he6710b0_0
222
+ - llvmlite=0.40.0=py38he621ea3_0
223
+ - locket=1.0.0=py38h06a4308_0
224
+ - lxml=4.9.3=py38hdbbb534_0
225
+ - lz4-c=1.9.4=h6a678d5_0
226
+ - lzo=2.10=h7b6447c_2
227
+ - markdown=3.4.1=py38h06a4308_0
228
+ - markupsafe=2.1.1=py38h7f8727e_0
229
+ - mathjax=2.7.5=h06a4308_0
230
+ - matplotlib=3.7.2=py38h06a4308_0
231
+ - matplotlib-base=3.7.2=py38h1128e8f_0
232
+ - matplotlib-inline=0.1.6=py38h06a4308_0
233
+ - mccabe=0.7.0=pyhd3eb1b0_0
234
+ - mistune=0.8.4=py38h7b6447c_1000
235
+ - mkl=2023.1.0=h213fc3f_46343
236
+ - mkl-service=2.4.0=py38h5eee18b_1
237
+ - mkl_fft=1.3.8=py38h5eee18b_0
238
+ - mkl_random=1.2.4=py38hdb19cb5_0
239
+ - more-itertools=8.12.0=pyhd3eb1b0_0
240
+ - mpc=1.1.0=h10f8cd9_1
241
+ - mpfr=4.0.2=hb69a4c5_1
242
+ - mpi=1.0=mpich
243
+ - mpich=4.1.1=hbae89fd_0
244
+ - mpmath=1.3.0=py38h06a4308_0
245
+ - msgpack-python=1.0.3=py38hd09550d_0
246
+ - multidict=6.0.2=py38h5eee18b_0
247
+ - multipledispatch=0.6.0=py38_0
248
+ - multiprocess=0.70.14=py38h06a4308_0
249
+ - munkres=1.1.4=py_0
250
+ - mypy_extensions=1.0.0=py38h06a4308_0
251
+ - mysql=5.7.24=h721c034_2
252
+ - nbclassic=0.5.5=py38h06a4308_0
253
+ - nbclient=0.5.13=py38h06a4308_0
254
+ - nbconvert=6.5.4=py38h06a4308_0
255
+ - nbformat=5.9.2=py38h06a4308_0
256
+ - ncurses=6.4=h6a678d5_0
257
+ - nest-asyncio=1.5.6=py38h06a4308_0
258
+ - networkx=3.1=py38h06a4308_0
259
+ - nltk=3.8.1=py38h06a4308_0
260
+ - notebook=6.5.4=py38h06a4308_1
261
+ - notebook-shim=0.2.2=py38h06a4308_0
262
+ - nspr=4.35=h6a678d5_0
263
+ - nss=3.89.1=h6a678d5_0
264
+ - numba=0.57.1=py38h1128e8f_0
265
+ - numexpr=2.8.4=py38hc78ab66_1
266
+ - numpy=1.24.3=py38hf6e8229_1
267
+ - numpy-base=1.24.3=py38h060ed82_1
268
+ - numpydoc=1.5.0=py38h06a4308_0
269
+ - oniguruma=6.9.7.1=h27cfd23_0
270
+ - openjpeg=2.4.0=h3ad879b_0
271
+ - openpyxl=3.0.10=py38h5eee18b_0
272
+ - openssl=3.0.10=h7f8727e_2
273
+ - orc=1.7.4=hb3bc3d3_1
274
+ - packaging=23.1=py38h06a4308_0
275
+ - pandas=2.0.3=py38h1128e8f_0
276
+ - pandocfilters=1.5.0=pyhd3eb1b0_0
277
+ - panel=0.14.3=py38h06a4308_0
278
+ - param=1.13.0=py38h06a4308_0
279
+ - parsel=1.6.0=py38h06a4308_0
280
+ - parso=0.8.3=pyhd3eb1b0_0
281
+ - partd=1.4.0=py38h06a4308_0
282
+ - pathspec=0.10.3=py38h06a4308_0
283
+ - patsy=0.5.3=py38h06a4308_0
284
+ - pcre=8.45=h295c915_0
285
+ - pep8=1.7.1=py38h06a4308_1
286
+ - pexpect=4.8.0=pyhd3eb1b0_3
287
+ - pickleshare=0.7.5=pyhd3eb1b0_1003
288
+ - pip=23.2.1=py38h06a4308_0
289
+ - pkgutil-resolve-name=1.3.10=py38h06a4308_0
290
+ - platformdirs=3.10.0=py38h06a4308_0
291
+ - plotly=5.9.0=py38h06a4308_0
292
+ - pluggy=1.0.0=py38h06a4308_1
293
+ - ply=3.11=py38_0
294
+ - pooch=1.4.0=pyhd3eb1b0_0
295
+ - poyo=0.5.0=pyhd3eb1b0_0
296
+ - prometheus_client=0.14.1=py38h06a4308_0
297
+ - prompt-toolkit=3.0.36=py38h06a4308_0
298
+ - prompt_toolkit=3.0.36=hd3eb1b0_0
299
+ - protego=0.1.16=py_0
300
+ - psutil=5.9.0=py38h5eee18b_0
301
+ - ptyprocess=0.7.0=pyhd3eb1b0_2
302
+ - pure_eval=0.2.2=pyhd3eb1b0_0
303
+ - py-cpuinfo=8.0.0=pyhd3eb1b0_1
304
+ - pyarrow=11.0.0=py38h468efa6_1
305
+ - pyasn1=0.4.8=pyhd3eb1b0_0
306
+ - pyasn1-modules=0.2.8=py_0
307
+ - pycodestyle=2.10.0=py38h06a4308_0
308
+ - pycparser=2.21=pyhd3eb1b0_0
309
+ - pyct=0.5.0=py38h06a4308_0
310
+ - pycurl=7.45.2=py38hdbd6064_1
311
+ - pydispatcher=2.0.5=py38h06a4308_2
312
+ - pydocstyle=6.3.0=py38h06a4308_0
313
+ - pyerfa=2.0.0=py38h27cfd23_0
314
+ - pyflakes=3.0.1=py38h06a4308_0
315
+ - pygments=2.15.1=py38h06a4308_1
316
+ - pylint=2.16.2=py38h06a4308_0
317
+ - pylint-venv=2.3.0=py38h06a4308_0
318
+ - pyls-spyder=0.4.0=pyhd3eb1b0_0
319
+ - pyodbc=4.0.34=py38h6a678d5_0
320
+ - pyopenssl=23.2.0=py38h06a4308_0
321
+ - pyqt=5.15.7=py38h6a678d5_1
322
+ - pyqt5-sip=12.11.0=py38h6a678d5_1
323
+ - pyqtwebengine=5.15.7=py38h6a678d5_1
324
+ - pyrsistent=0.18.0=py38heee7806_0
325
+ - pysocks=1.7.1=py38h06a4308_0
326
+ - pytables=3.8.0=py38hb8ae3fc_3
327
+ - pytest=7.4.0=py38h06a4308_0
328
+ - python=3.8.18=h955ad1f_0
329
+ - python-dateutil=2.8.2=pyhd3eb1b0_0
330
+ - python-fastjsonschema=2.16.2=py38h06a4308_0
331
+ - python-json-logger=2.0.7=py38h06a4308_0
332
+ - python-kaleido=0.2.1=py38h06a4308_0
333
+ - python-lmdb=1.4.1=py38h6a678d5_0
334
+ - python-lsp-black=1.2.1=py38h06a4308_0
335
+ - python-lsp-jsonrpc=1.0.0=pyhd3eb1b0_0
336
+ - python-lsp-server=1.7.2=py38h06a4308_0
337
+ - python-slugify=5.0.2=pyhd3eb1b0_0
338
+ - python-snappy=0.6.1=py38h6a678d5_0
339
+ - python-tzdata=2023.3=pyhd3eb1b0_0
340
+ - python-xxhash=2.0.2=py38h5eee18b_1
341
+ - pytoolconfig=1.2.5=py38h06a4308_1
342
+ - pytz=2023.3.post1=py38h06a4308_0
343
+ - pyviz_comms=2.3.0=py38h06a4308_0
344
+ - pywavelets=1.4.1=py38h5eee18b_0
345
+ - pyxdg=0.27=pyhd3eb1b0_0
346
+ - pyyaml=6.0=py38h5eee18b_1
347
+ - pyzmq=23.2.0=py38h6a678d5_0
348
+ - qdarkstyle=3.0.2=pyhd3eb1b0_0
349
+ - qstylizer=0.2.2=py38h06a4308_0
350
+ - qt-main=5.15.2=h7358343_9
351
+ - qt-webengine=5.15.9=h9ab4d14_7
352
+ - qtawesome=1.2.2=py38h06a4308_0
353
+ - qtconsole=5.4.2=py38h06a4308_0
354
+ - qtpy=2.2.0=py38h06a4308_0
355
+ - qtwebkit=5.212=h3fafdc1_5
356
+ - queuelib=1.5.0=py38h06a4308_0
357
+ - re2=2022.04.01=h295c915_0
358
+ - readline=8.2=h5eee18b_0
359
+ - regex=2022.7.9=py38h5eee18b_0
360
+ - requests=2.31.0=py38h06a4308_0
361
+ - requests-file=1.5.1=pyhd3eb1b0_0
362
+ - responses=0.13.3=pyhd3eb1b0_0
363
+ - rfc3339-validator=0.1.4=py38h06a4308_0
364
+ - rfc3986-validator=0.1.1=py38h06a4308_0
365
+ - rope=1.7.0=py38h06a4308_0
366
+ - rtree=1.0.1=py38h06a4308_0
367
+ - s3fs=2023.4.0=py38h06a4308_0
368
+ - safetensors=0.3.2=py38hb02cf49_0
369
+ - scikit-image=0.19.3=py38h6a678d5_1
370
+ - scikit-learn=1.3.0=py38h1128e8f_0
371
+ - scikit-learn-intelex=2023.1.1=py38h06a4308_0
372
+ - scipy=1.10.1=py38hf6e8229_1
373
+ - scrapy=2.8.0=py38h06a4308_0
374
+ - seaborn=0.12.2=py38h06a4308_0
375
+ - secretstorage=3.3.1=py38h06a4308_1
376
+ - send2trash=1.8.0=pyhd3eb1b0_1
377
+ - service_identity=18.1.0=pyhd3eb1b0_1
378
+ - setuptools=68.0.0=py38h06a4308_0
379
+ - sip=6.6.2=py38h6a678d5_0
380
+ - six=1.16.0=pyhd3eb1b0_1
381
+ - smart_open=5.2.1=py38h06a4308_0
382
+ - snappy=1.1.9=h295c915_0
383
+ - sniffio=1.2.0=py38h06a4308_1
384
+ - snowballstemmer=2.2.0=pyhd3eb1b0_0
385
+ - sortedcontainers=2.4.0=pyhd3eb1b0_0
386
+ - soupsieve=2.4=py38h06a4308_0
387
+ - sphinx=5.0.2=py38h06a4308_0
388
+ - sphinxcontrib-applehelp=1.0.2=pyhd3eb1b0_0
389
+ - sphinxcontrib-devhelp=1.0.2=pyhd3eb1b0_0
390
+ - sphinxcontrib-htmlhelp=2.0.0=pyhd3eb1b0_0
391
+ - sphinxcontrib-jsmath=1.0.1=pyhd3eb1b0_0
392
+ - sphinxcontrib-qthelp=1.0.3=pyhd3eb1b0_0
393
+ - sphinxcontrib-serializinghtml=1.1.5=pyhd3eb1b0_0
394
+ - spyder=5.4.3=py38h06a4308_1
395
+ - spyder-kernels=2.4.4=py38h06a4308_0
396
+ - sqlalchemy=1.4.39=py38h5eee18b_0
397
+ - sqlite=3.41.2=h5eee18b_0
398
+ - stack_data=0.2.0=pyhd3eb1b0_0
399
+ - statsmodels=0.14.0=py38ha9d4c09_0
400
+ - sympy=1.11.1=py38h06a4308_0
401
+ - tabulate=0.8.10=py38h06a4308_0
402
+ - tbb=2021.8.0=hdb19cb5_0
403
+ - tbb4py=2021.8.0=py38hdb19cb5_0
404
+ - tblib=1.7.0=pyhd3eb1b0_0
405
+ - tenacity=8.2.2=py38h06a4308_0
406
+ - terminado=0.17.1=py38h06a4308_0
407
+ - text-unidecode=1.3=pyhd3eb1b0_0
408
+ - textdistance=4.2.1=pyhd3eb1b0_0
409
+ - threadpoolctl=2.2.0=pyh0d69192_0
410
+ - three-merge=0.1.1=pyhd3eb1b0_0
411
+ - tifffile=2023.4.12=py38h06a4308_0
412
+ - tinycss2=1.2.1=py38h06a4308_0
413
+ - tk=8.6.12=h1ccaba5_0
414
+ - tldextract=3.2.0=pyhd3eb1b0_0
415
+ - toml=0.10.2=pyhd3eb1b0_0
416
+ - tomli=2.0.1=py38h06a4308_0
417
+ - tomlkit=0.11.1=py38h06a4308_0
418
+ - toolz=0.12.0=py38h06a4308_0
419
+ - tornado=6.3.2=py38h5eee18b_0
420
+ - tqdm=4.65.0=py38hb070fc8_0
421
+ - traitlets=5.7.1=py38h06a4308_0
422
+ - twisted=22.10.0=py38h5eee18b_0
423
+ - typing_extensions=4.7.1=py38h06a4308_0
424
+ - ujson=5.4.0=py38h6a678d5_0
425
+ - unidecode=1.2.0=pyhd3eb1b0_0
426
+ - unixodbc=2.3.11=h5eee18b_0
427
+ - urllib3=1.26.16=py38h06a4308_0
428
+ - utf8proc=2.6.1=h27cfd23_0
429
+ - w3lib=1.21.0=pyhd3eb1b0_0
430
+ - watchdog=2.1.6=py38h06a4308_0
431
+ - wcwidth=0.2.5=pyhd3eb1b0_0
432
+ - webencodings=0.5.1=py38_1
433
+ - websocket-client=0.58.0=py38h06a4308_4
434
+ - werkzeug=2.2.3=py38h06a4308_0
435
+ - whatthepatch=1.0.2=py38h06a4308_0
436
+ - wheel=0.38.4=py38h06a4308_0
437
+ - widgetsnbextension=4.0.5=py38h06a4308_0
438
+ - wrapt=1.14.1=py38h5eee18b_0
439
+ - wurlitzer=3.0.2=py38h06a4308_0
440
+ - xarray=2022.11.0=py38h06a4308_0
441
+ - xxhash=0.8.0=h7f8727e_3
442
+ - xz=5.4.2=h5eee18b_0
443
+ - y-py=0.5.9=py38h52d8a92_0
444
+ - yaml=0.2.5=h7b6447c_0
445
+ - yapf=0.31.0=pyhd3eb1b0_0
446
+ - yarl=1.8.1=py38h5eee18b_0
447
+ - ypy-websocket=0.8.2=py38h06a4308_0
448
+ - zeromq=4.3.4=h2531618_0
449
+ - zfp=1.0.0=h6a678d5_0
450
+ - zict=2.2.0=py38h06a4308_0
451
+ - zipp=3.11.0=py38h06a4308_0
452
+ - zlib=1.2.13=h5eee18b_0
453
+ - zlib-ng=2.0.7=h5eee18b_0
454
+ - zope=1.0=py38_1
455
+ - zope.interface=5.4.0=py38h7f8727e_0
456
+ - zstd=1.5.5=hc292b87_0
457
+ - pip:
458
+ - absl-py==2.0.0
459
+ - addict==2.4.0
460
+ - altair==5.1.2
461
+ - annotated-types==0.6.0
462
+ - antlr4-python3-runtime==4.9.3
463
+ - anyio==3.7.1
464
+ - autodistill==0.1.16
465
+ - autodistill-detic==0.1.4
466
+ - autodistill-fastsam==0.1.0
467
+ - autodistill-grounded-sam==0.1.1
468
+ - autodistill-grounding-dino==0.1.2
469
+ - autodistill-llava==0.1.0
470
+ - autodistill-metaclip==0.1.1
471
+ - autodistill-owl-vit==0.1.1
472
+ - autodistill-owlv2==0.1.0
473
+ - autodistill-sam-clip==0.1.3
474
+ - autodistill-seggpt==0.1.6
475
+ - autodistill-yolov8==0.1.2
476
+ - black==22.3.0
477
+ - cachetools==5.3.1
478
+ - clip==1.0
479
+ - cmake==3.27.5
480
+ - combinadics==0.0.3
481
+ - cycler==0.10.0
482
+ - cython==3.0.4
483
+ - dataclasses==0.6
484
+ - detectron2-layers==0.0.5
485
+ - einops==0.7.0
486
+ - einops-exts==0.0.4
487
+ - fairscale==0.4.13
488
+ - fastapi==0.104.1
489
+ - fasttext==0.9.2
490
+ - ffmpy==0.3.1
491
+ - ftfy==6.1.1
492
+ - future==0.18.3
493
+ - fvcore==0.1.5.post20221221
494
+ - google-auth==2.23.0
495
+ - google-auth-oauthlib==1.0.0
496
+ - gradio==3.35.2
497
+ - gradio-client==0.7.0
498
+ - grpcio==1.58.0
499
+ - h11==0.14.0
500
+ - httpcore==0.18.0
501
+ - httpx==0.25.0
502
+ - huggingface-hub==0.17.3
503
+ - hydra-core==1.3.2
504
+ - idna==2.10
505
+ - iopath==0.1.9
506
+ - linkify-it-py==2.0.2
507
+ - lit==16.0.6
508
+ - llava==0.0.1.dev0
509
+ - lvis==0.5.3
510
+ - markdown-it-py==2.2.0
511
+ - mdit-py-plugins==0.3.3
512
+ - mdurl==0.1.2
513
+ - mss==9.0.1
514
+ - natsort==8.4.0
515
+ - nvidia-cublas-cu11==11.10.3.66
516
+ - nvidia-cuda-cupti-cu11==11.7.101
517
+ - nvidia-cuda-nvrtc-cu11==11.7.99
518
+ - nvidia-cuda-runtime-cu11==11.7.99
519
+ - nvidia-cudnn-cu11==8.5.0.96
520
+ - nvidia-cufft-cu11==10.9.0.58
521
+ - nvidia-curand-cu11==10.2.10.91
522
+ - nvidia-cusolver-cu11==11.4.0.1
523
+ - nvidia-cusparse-cu11==11.7.4.91
524
+ - nvidia-nccl-cu11==2.14.3
525
+ - nvidia-nvtx-cu11==11.7.91
526
+ - oauthlib==3.2.2
527
+ - omegaconf==2.3.0
528
+ - onnx==1.14.1
529
+ - onnx-simplifier==0.4.33
530
+ - open-clip-torch==2.23.0
531
+ - open-flamingo==2.0.1
532
+ - opencv-python==4.8.0.76
533
+ - opencv-python-headless==4.8.0.74
534
+ - orjson==3.9.10
535
+ - pillow==8.4.0
536
+ - portalocker==2.8.2
537
+ - protobuf==4.24.3
538
+ - pybind11==2.11.1
539
+ - pycocotools==2.0.7
540
+ - pydantic==2.4.2
541
+ - pydantic-core==2.10.1
542
+ - pydot==1.4.2
543
+ - pydub==0.25.1
544
+ - pyparsing==2.4.7
545
+ - python-dotenv==1.0.0
546
+ - python-magic==0.4.27
547
+ - python-multipart==0.0.6
548
+ - requests-oauthlib==1.3.1
549
+ - requests-toolbelt==1.0.0
550
+ - rf-groundingdino==0.1.2
551
+ - rf-segment-anything==1.0
552
+ - rich==13.5.3
553
+ - roboflow==1.1.9
554
+ - rsa==4.9
555
+ - semantic-version==2.10.0
556
+ - sentencepiece==0.1.98
557
+ - sentry-sdk==1.34.0
558
+ - starlette==0.27.0
559
+ - supervision==0.9.0
560
+ - tensorboard==2.14.0
561
+ - tensorboard-data-server==0.7.1
562
+ - termcolor==2.3.0
563
+ - thop==0.1.1-2209072238
564
+ - timm==0.9.8
565
+ - tokenizers==0.14.1
566
+ - torch==2.0.1
567
+ - torchvision==0.15.2
568
+ - transformers==4.35.0.dev0
569
+ - triton==2.0.0
570
+ - typing-extensions==4.8.0
571
+ - uc-micro-py==1.0.2
572
+ - ultralytics==8.0.81
573
+ - uvicorn==0.23.2
574
+ - websockets==11.0.3
575
+ - yacs==0.1.8
environment_initial.yml ADDED
@@ -0,0 +1,139 @@
1
+ # Environment used to reproduce results
2
+ name: peekaboo
3
+ channels:
4
+ - defaults
5
+ - conda-forge
6
+ dependencies:
7
+ - _libgcc_mutex=0.1=main
8
+ - _openmp_mutex=5.1=1_gnu
9
+ - asttokens=2.4.0=pyhd8ed1ab_0
10
+ - backcall=0.2.0=pyh9f0ad1d_0
11
+ - backports=1.0=pyhd8ed1ab_3
12
+ - backports.functools_lru_cache=1.6.5=pyhd8ed1ab_0
13
+ - ca-certificates=2023.7.22=hbcca054_0
14
+ - comm=0.1.4=pyhd8ed1ab_0
15
+ - debugpy=1.6.7=py38h6a678d5_0
16
+ - entrypoints=0.4=pyhd8ed1ab_0
17
+ - executing=1.2.0=pyhd8ed1ab_0
18
+ - ipykernel=6.25.2=pyh2140261_0
19
+ - ipython=8.12.0=pyh41d4057_0
20
+ - jedi=0.19.1=pyhd8ed1ab_0
21
+ - jupyter_client=7.3.4=pyhd8ed1ab_0
22
+ - jupyter_core=5.4.0=py38h578d9bd_0
23
+ - ld_impl_linux-64=2.38=h1181459_1
24
+ - libffi=3.4.4=h6a678d5_0
25
+ - libgcc-ng=11.2.0=h1234567_1
26
+ - libgomp=11.2.0=h1234567_1
27
+ - libsodium=1.0.18=h36c2ea0_1
28
+ - libstdcxx-ng=11.2.0=h1234567_1
29
+ - matplotlib-inline=0.1.6=pyhd8ed1ab_0
30
+ - ncurses=6.4=h6a678d5_0
31
+ - nest-asyncio=1.5.8=pyhd8ed1ab_0
32
+ - openssl=3.0.11=h7f8727e_2
33
+ - packaging=23.2=pyhd8ed1ab_0
34
+ - parso=0.8.3=pyhd8ed1ab_0
35
+ - pexpect=4.8.0=pyh1a96a4e_2
36
+ - pickleshare=0.7.5=py_1003
37
+ - pip=23.1.2=py38h06a4308_0
38
+ - platformdirs=3.11.0=pyhd8ed1ab_0
39
+ - prompt-toolkit=3.0.39=pyha770c72_0
40
+ - prompt_toolkit=3.0.39=hd8ed1ab_0
41
+ - psutil=5.9.0=py38h5eee18b_0
42
+ - ptyprocess=0.7.0=pyhd3deb0d_0
43
+ - pure_eval=0.2.2=pyhd8ed1ab_0
44
+ - python=3.8.16=h955ad1f_4
45
+ - python-dateutil=2.8.2=pyhd8ed1ab_0
46
+ - python_abi=3.8=2_cp38
47
+ - pyzmq=25.1.0=py38h6a678d5_0
48
+ - readline=8.2=h5eee18b_0
49
+ - setuptools=67.8.0=py38h06a4308_0
50
+ - six=1.16.0=pyh6c4a22f_0
51
+ - sqlite=3.41.2=h5eee18b_0
52
+ - stack_data=0.6.2=pyhd8ed1ab_0
53
+ - tk=8.6.12=h1ccaba5_0
54
+ - tornado=6.1=py38h0a891b7_3
55
+ - traitlets=5.11.2=pyhd8ed1ab_0
56
+ - typing_extensions=4.8.0=pyha770c72_0
57
+ - wcwidth=0.2.8=pyhd8ed1ab_0
58
+ - wheel=0.38.4=py38h06a4308_0
59
+ - xz=5.4.2=h5eee18b_0
60
+ - zeromq=4.3.4=h2531618_0
61
+ - zlib=1.2.13=h5eee18b_0
62
+ - pip:
63
+ - absl-py==1.4.0
64
+ - addict==2.4.0
65
+ - cachetools==5.3.1
66
+ - certifi==2023.5.7
67
+ - charset-normalizer==3.1.0
68
+ - cmake==3.26.4
69
+ - decorator==4.4.2
70
+ - filelock==3.12.2
71
+ - fonttools==4.41.0
72
+ - google-auth==2.22.0
73
+ - google-auth-oauthlib==1.0.0
74
+ - grpcio==1.56.0
75
+ - idna==3.4
76
+ - imageio==2.31.1
77
+ - imageio-ffmpeg==0.4.8
78
+ - importlib-metadata==6.8.0
79
+ - jinja2==3.1.2
80
+ - kiwisolver==1.4.4
81
+ - labelimg==1.8.6
82
+ - lazy-loader==0.3
83
+ - lit==16.0.6
84
+ - lxml==4.9.2
85
+ - markdown==3.4.3
86
+ - markdown-it-py==3.0.0
87
+ - markupsafe==2.1.3
88
+ - matplotlib==3.7.2
89
+ - mdurl==0.1.2
90
+ - moviepy==1.0.3
91
+ - mpmath==1.3.0
92
+ - networkx==3.1
93
+ - numpy==1.24.4
94
+ - nvidia-cublas-cu11==11.10.3.66
95
+ - nvidia-cuda-cupti-cu11==11.7.101
96
+ - nvidia-cuda-nvrtc-cu11==11.7.99
97
+ - nvidia-cuda-runtime-cu11==11.7.99
98
+ - nvidia-cudnn-cu11==8.5.0.96
99
+ - nvidia-cufft-cu11==10.9.0.58
100
+ - nvidia-curand-cu11==10.2.10.91
101
+ - nvidia-cusolver-cu11==11.4.0.1
102
+ - nvidia-cusparse-cu11==11.7.4.91
103
+ - nvidia-nccl-cu11==2.14.3
104
+ - nvidia-nvtx-cu11==11.7.91
105
+ - oauthlib==3.2.2
106
+ - onnx==1.14.0
107
+ - onnx-simplifier==0.4.33
108
+ - opencv-python==4.5.5.64
109
+ - opencv-python-headless==4.5.5.64
110
+ - pillow==9.5.0
111
+ - proglog==0.1.10
112
+ - protobuf==4.23.4
113
+ - pyasn1==0.5.0
114
+ - pyasn1-modules==0.3.0
115
+ - pycocotools==2.0.6
116
+ - pygments==2.15.1
117
+ - pyqt5==5.15.9
118
+ - pyqt5-qt5==5.15.2
119
+ - pyqt5-sip==12.12.1
120
+ - pywavelets==1.4.1
121
+ - pyyaml==6.0
122
+ - requests==2.31.0
123
+ - requests-oauthlib==1.3.1
124
+ - rich==13.4.2
125
+ - rsa==4.9
126
+ - scikit-image==0.21.0
127
+ - scipy==1.10.1
128
+ - sympy==1.12
129
+ - tensorboard==2.13.0
130
+ - tensorboard-data-server==0.7.1
131
+ - thop==0.1.1-2209072238
132
+ - tifffile==2023.7.10
133
+ - torch==2.0.1
134
+ - torchvision==0.15.2
135
+ - tqdm==4.65.0
136
+ - triton==2.0.0
137
+ - typing-extensions==4.7.1
138
+ - urllib3==1.26.16
139
+ - werkzeug==2.3.6
evaluate.py ADDED
@@ -0,0 +1,110 @@
1
+ # Copyright 2022 - Valeo Comfort and Driving Assistance - Oriane Siméoni @ valeo.ai
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ from model import PeekabooModel
17
+ from misc import load_config
18
+ from datasets.datasets import build_dataset
19
+ from evaluation.saliency import evaluate_saliency
20
+ from evaluation.uod import evaluation_unsupervised_object_discovery
21
+
22
+ if __name__ == "__main__":
23
+ parser = argparse.ArgumentParser(
24
+ description="Evaluation of Peekaboo",
25
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
26
+ )
27
+ parser.add_argument(
28
+ "--eval-type", type=str, choices=["saliency", "uod"], help="Evaluation type."
29
+ )
30
+ parser.add_argument(
31
+ "--dataset-eval",
32
+ type=str,
33
+ choices=["ECSSD", "DUT-OMRON", "DUTS-TEST", "VOC07", "VOC12", "COCO20k"],
34
+ help="Name of evaluation dataset.",
35
+ )
36
+ parser.add_argument(
37
+ "--dataset-set-eval", type=str, default=None, help="Set of the dataset."
38
+ )
39
+ parser.add_argument(
40
+ "--apply-bilateral", action="store_true", help="use bilateral solver."
41
+ )
42
+ parser.add_argument(
43
+ "--evaluation-mode",
44
+ type=str,
45
+ default="multi",
46
+ choices=["single", "multi"],
47
+ help="Type of evaluation.",
48
+ )
49
+ parser.add_argument(
50
+ "--model-weights",
51
+ type=str,
52
+ default="data/weights/decoder_weights.pt",
53
+ )
54
+ parser.add_argument(
55
+ "--dataset-dir",
56
+ type=str,
57
+ )
58
+ parser.add_argument(
59
+ "--config",
60
+ type=str,
61
+ default="configs/peekaboo_DUTS-TR.yaml",
62
+ )
63
+ args = parser.parse_args()
64
+ print(args.__dict__)
65
+
66
+ # Configuration
67
+ config, _ = load_config(args.config)
68
+
69
+ # Load the model
70
+ model = PeekabooModel(
71
+ vit_model=config.model["pre_training"],
72
+ vit_arch=config.model["arch"],
73
+ vit_patch_size=config.model["patch_size"],
74
+ enc_type_feats=config.peekaboo["feats"],
75
+ )
76
+ # Load weights
77
+ model.decoder_load_weights(args.model_weights)
78
+ model.eval()
79
+ print(f"Model {args.model_weights} loaded correctly.")
80
+
81
+ # Build the validation set
82
+ val_dataset = build_dataset(
83
+ root_dir=args.dataset_dir,
84
+ dataset_name=args.dataset_eval,
85
+ dataset_set=args.dataset_set_eval,
86
+ for_eval=True,
87
+ evaluation_type=args.eval_type,
88
+ )
89
+ print(f"\nBuilding dataset {val_dataset.name} (#{len(val_dataset)} images)")
90
+
91
+ # Validation
92
+ print(f"\nStarted evaluation on {val_dataset.name}")
93
+ if args.eval_type == "saliency":
94
+ evaluate_saliency(
95
+ val_dataset,
96
+ model=model,
97
+ evaluation_mode=args.evaluation_mode,
98
+ apply_bilateral=args.apply_bilateral,
99
+ )
100
+ elif args.eval_type == "uod":
101
+ if args.apply_bilateral:
102
+ raise ValueError("Not implemented.")
103
+
104
+ evaluation_unsupervised_object_discovery(
105
+ val_dataset,
106
+ model=model,
107
+ evaluation_mode=args.evaluation_mode,
108
+ )
109
+ else:
110
+ raise ValueError("Other evaluation method to come.")
evaluate_saliency.sh ADDED
@@ -0,0 +1,12 @@
1
+ MODEL=$1
2
+ DATASET_DIR=$2
3
+ MODE=$3
4
+
5
+ # Unsupervised saliency detection evaluation
6
+ for DATASET in ECSSD DUTS-TEST DUT-OMRON
7
+ do
8
+ python evaluate.py --eval-type saliency --dataset-eval $DATASET \
9
+ --model-weights $MODEL --evaluation-mode $MODE --apply-bilateral --dataset-dir $DATASET_DIR
10
+ done
11
+
12
+
evaluate_uod.sh ADDED
@@ -0,0 +1,11 @@
1
+ MODEL=$1
2
+ DATASET_DIR=$2
3
+
4
+ # Single object discovery evaluation
5
+ for DATASET in VOC07 VOC12 COCO20k
6
+ do
7
+ python evaluate.py --eval-type uod --dataset-eval $DATASET \
8
+ --model-weights $MODEL --evaluation-mode single --dataset-dir $DATASET_DIR
9
+ done
10
+
11
+
evaluation/__init__.py ADDED
File without changes
evaluation/metrics/__init__.py ADDED
File without changes
evaluation/metrics/average_meter.py ADDED
@@ -0,0 +1,22 @@
1
+ """
2
+ Code borrowed from SelfMask: https://github.com/NoelShin/selfmask
3
+ """
4
+
5
+
6
+ class AverageMeter(object):
7
+ """Computes and stores the average and current value"""
8
+
9
+ def __init__(self):
10
+ self.reset()
11
+
12
+ def reset(self):
13
+ self.val = 0
14
+ self.avg = 0
15
+ self.sum = 0
16
+ self.count = 0
17
+
18
+ def update(self, val, n: int):
19
+ self.val = val
20
+ self.sum += val * n
21
+ self.count += n
22
+ self.avg = self.sum / self.count
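AverageMeter is the running-average container that the saliency evaluation fills with per-image metric values. A tiny sketch with made-up numbers, not part of this commit:

```python
# Minimal usage sketch (illustration only, not part of this commit).
from evaluation.metrics.average_meter import AverageMeter

meter = AverageMeter()
for iou, n_images in [(0.8, 1), (0.6, 1), (0.7, 2)]:  # made-up per-batch IoU values
    meter.update(val=iou, n=n_images)
print(meter.avg)  # (0.8*1 + 0.6*1 + 0.7*2) / 4 = 0.7
```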
evaluation/metrics/f_measure.py ADDED
@@ -0,0 +1,112 @@
1
+ """
2
+ Code borrowed from SelfMask: https://github.com/NoelShin/selfmask
3
+ """
4
+
5
+ import torch
6
+
7
+
8
+ class FMeasure:
9
+ def __init__(
10
+ self,
11
+ default_thres: float = 0.5,
12
+ beta_square: float = 0.3,
13
+ n_bins: int = 255,
14
+ eps: float = 1e-7,
15
+ ):
16
+ """
17
+ :param default_thres: a hyperparameter for F-measure that is used to binarize a predicted mask. Default: 0.5
18
+ :param beta_square: a hyperparameter for F-measure. Default: 0.3
19
+ :param n_bins: the number of thresholds that will be tested for F-max. Default: 255
20
+ :param eps: a small value for numerical stability
21
+ """
22
+
23
+ self.beta_square = beta_square
24
+ self.default_thres = default_thres
25
+ self.eps = eps
26
+ self.n_bins = n_bins
27
+
28
+ def _compute_precision_recall(
29
+ self, binary_pred_mask: torch.Tensor, gt_mask: torch.Tensor
30
+ ) -> torch.Tensor:
31
+ """
32
+ :param binary_pred_mask: (B x H x W) or (H x W)
33
+ :param gt_mask: (B x H x W) or (H x W), should be the same with binary_pred_mask
34
+ """
35
+ tp = torch.logical_and(binary_pred_mask, gt_mask).sum(dim=(-1, -2))
36
+ tp_fp = binary_pred_mask.sum(dim=(-1, -2))
37
+ tp_fn = gt_mask.sum(dim=(-1, -2))
38
+
39
+ prec = tp / (tp_fp + self.eps)
40
+ recall = tp / (tp_fn + self.eps)
41
+ return prec, recall
42
+
43
+ def _compute_f_measure(
44
+ self,
45
+ pred_mask: torch.Tensor,
46
+ gt_mask: torch.Tensor,
47
+ thresholds: torch.Tensor = None,
48
+ ) -> torch.Tensor:
49
+ if thresholds is None:
50
+ binary_pred_mask = pred_mask > self.default_thres
51
+ else:
52
+ binary_pred_mask = pred_mask > thresholds
53
+
54
+ prec, recall = self._compute_precision_recall(binary_pred_mask, gt_mask)
55
+ f_measure = ((1 + (self.beta_square**2)) * prec * recall) / (
56
+ (self.beta_square**2) * prec + recall + self.eps
57
+ )
58
+ return f_measure.cpu()
59
+
60
+ def _compute_f_max(
61
+ self, pred_mask: torch.Tensor, gt_mask: torch.Tensor
62
+ ) -> torch.Tensor:
63
+ """Compute self.n_bins F-measures, each of which has a different threshold, then return the maximum
64
+ F-measure among them.
65
+
66
+ :param pred_mask: (H x W)
67
+ :param gt_mask: (H x W)
68
+ """
69
+
70
+ # pred_masks, gt_masks: H x W -> self.n_bins x H x W
71
+ pred_masks = pred_mask.unsqueeze(dim=0).repeat(self.n_bins, 1, 1)
72
+ gt_masks = gt_mask.unsqueeze(dim=0).repeat(self.n_bins, 1, 1)
73
+
74
+ # thresholds: self.n_bins x 1 x 1
75
+ thresholds = (
76
+ torch.arange(0, 1, 1 / self.n_bins)
77
+ .view(self.n_bins, 1, 1)
78
+ .to(pred_masks.device)
79
+ )
80
+
81
+ # f_measures: self.n_bins
82
+ f_measures = self._compute_f_measure(pred_masks, gt_masks, thresholds)
83
+ return torch.max(f_measures).cpu(), f_measures
84
+
85
+ def _compute_f_mean(
86
+ self,
87
+ pred_mask: torch.Tensor,
88
+ gt_mask: torch.Tensor,
89
+ ) -> torch.Tensor:
90
+ adaptive_thres = 2 * pred_mask.mean(dim=(-1, -2), keepdim=True)
91
+ binary_pred_mask = pred_mask > adaptive_thres
92
+
93
+ prec, recall = self._compute_precision_recall(binary_pred_mask, gt_mask)
94
+ f_mean = ((1 + (self.beta_square**2)) * prec * recall) / (
95
+ (self.beta_square**2) * prec + recall + self.eps
96
+ )
97
+ return f_mean.cpu()
98
+
99
+ def __call__(self, pred_mask: torch.Tensor, gt_mask: torch.Tensor) -> dict:
100
+ """
101
+ :param pred_mask: (H x W) a normalized prediction mask with values in [0, 1]
102
+ :param gt_mask: (H x W) a binary ground truth mask with values in {0, 1}
103
+ :return: a dictionary with keys being "f_measure" and "f_max" and values being the respective values.
104
+ """
105
+ outputs: dict = dict()
106
+ for k in ("f_measure", "f_mean"):
107
+ outputs.update({k: getattr(self, f"_compute_{k}")(pred_mask, gt_mask)})
108
+
109
+ f_max_, all_f = self._compute_f_max(pred_mask, gt_mask)
110
+ outputs["f_max"] = f_max_
111
+ outputs["all_f"] = all_f # List of all f values for all thresholds
112
+ return outputs
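FMeasure returns the F-score at the fixed 0.5 threshold, the adaptive-threshold F-mean, F-max over the 255 binarization thresholds, and the per-threshold values. A minimal sketch with random tensors, illustration only and not part of this commit:

```python
# Minimal usage sketch (illustration only, not part of this commit).
import torch
from evaluation.metrics.f_measure import FMeasure

pred = torch.rand(64, 64)               # soft prediction mask in [0, 1]
gt = (torch.rand(64, 64) > 0.5).long()  # binary ground-truth mask
scores = FMeasure()(pred, gt)
print(scores["f_measure"], scores["f_mean"], scores["f_max"])
print(scores["all_f"].shape)            # one F value per threshold: torch.Size([255])
```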
evaluation/metrics/iou.py ADDED
@@ -0,0 +1,37 @@
1
+ """
2
+ Code adapted from SelfMask: https://github.com/NoelShin/selfmask
3
+ """
4
+
5
+ from typing import Optional, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+
11
+ def compute_iou(
12
+ pred_mask: Union[np.ndarray, torch.Tensor],
13
+ gt_mask: Union[np.ndarray, torch.Tensor],
14
+ threshold: Optional[float] = 0.5,
15
+ eps: float = 1e-7,
16
+ ) -> Union[np.ndarray, torch.Tensor]:
17
+ """
18
+ :param pred_mask: (B x H x W) or (H x W)
19
+ :param gt_mask: (B x H x W) or (H x W), same shape with pred_mask
20
+ :param threshold: a binarization threshold
21
+ :param eps: a small value for computational stability
22
+ :return: (B) or (1)
23
+ """
24
+ assert pred_mask.shape == gt_mask.shape, f"{pred_mask.shape} != {gt_mask.shape}"
25
+ # assert 0. <= pred_mask.to(torch.float32).min() and pred_mask.max().to(torch.float32) <= 1., f"{pred_mask.min(), pred_mask.max()}"
26
+
27
+ if threshold is not None:
28
+ pred_mask = pred_mask > threshold
29
+ if isinstance(pred_mask, np.ndarray):
30
+ intersection = np.logical_and(pred_mask, gt_mask).sum(axis=(-1, -2))
31
+ union = np.logical_or(pred_mask, gt_mask).sum(axis=(-1, -2))
32
+ ious = intersection / (union + eps)
33
+ else:
34
+ intersection = torch.logical_and(pred_mask, gt_mask).sum(dim=(-1, -2))
35
+ union = torch.logical_or(pred_mask, gt_mask).sum(dim=(-1, -2))
36
+ ious = (intersection / (union + eps)).cpu()
37
+ return ious
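compute_iou binarizes the prediction at 0.5 by default and then takes intersection over union; a toy example with made-up values, not part of this commit:

```python
# Minimal usage sketch (illustration only, not part of this commit).
import torch
from evaluation.metrics.iou import compute_iou

pred = torch.tensor([[0.9, 0.2],
                     [0.8, 0.1]])   # soft prediction
gt = torch.tensor([[1, 0],
                   [1, 1]])         # binary ground truth
print(compute_iou(pred, gt))        # intersection 2, union 3 -> ~0.667
```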
evaluation/metrics/mae.py ADDED
@@ -0,0 +1,15 @@
1
+ """
2
+ Code borrowed from SelfMask: https://github.com/NoelShin/selfmask
3
+ """
4
+
5
+ import torch
6
+
7
+
8
+ def compute_mae(pred_mask: torch.Tensor, gt_mask: torch.Tensor) -> torch.Tensor:
9
+ """
10
+ :param pred_mask: (H x W) or (B x H x W) a normalized prediction mask with values in [0, 1]
11
+ :param gt_mask: (H x W) or (B x H x W) a binary ground truth mask with values in {0, 1}
12
+ """
13
+ return torch.mean(
14
+ torch.abs(pred_mask - gt_mask.to(torch.float32)), dim=(-1, -2)
15
+ ).cpu()
evaluation/metrics/pixel_acc.py ADDED
@@ -0,0 +1,21 @@
1
+ """
2
+ Code borrowed from SelfMask: https://github.com/NoelShin/selfmask
3
+ """
4
+
5
+ from typing import Optional
6
+
7
+ import torch
8
+
9
+
10
+ def compute_pixel_accuracy(
11
+ pred_mask: torch.Tensor, gt_mask: torch.Tensor, threshold: Optional[float] = 0.5
12
+ ) -> torch.Tensor:
13
+ """
14
+ :param pred_mask: (H x W) or (B x H x W) a normalized prediction mask with values in [0, 1]
15
+ :param gt_mask: (H x W) or (B x H x W) a binary ground truth mask with values in {0, 1}
16
+ """
17
+ if threshold is not None:
18
+ binary_pred_mask = pred_mask > threshold
19
+ else:
20
+ binary_pred_mask = pred_mask
21
+ return (binary_pred_mask == gt_mask).to(torch.float32).mean(dim=(-1, -2)).cpu()
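The two helpers above, compute_mae and compute_pixel_accuracy, operate on the same mask pair; a toy example with made-up values, not part of this commit:

```python
# Minimal usage sketch (illustration only, not part of this commit).
import torch
from evaluation.metrics.mae import compute_mae
from evaluation.metrics.pixel_acc import compute_pixel_accuracy

pred = torch.tensor([[0.9, 0.1],
                     [0.4, 0.8]])
gt = torch.tensor([[1, 0],
                   [1, 1]])
print(compute_mae(pred, gt))              # (0.1 + 0.1 + 0.6 + 0.2) / 4 = 0.25
print(compute_pixel_accuracy(pred, gt))   # 3 of 4 pixels correct after thresholding at 0.5
```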
evaluation/metrics/s_measure.py ADDED
@@ -0,0 +1,126 @@
1
+ # code borrowed from https://github.com/Hanqer/Evaluate-SOD/blob/master/evaluator.py
2
+ import numpy as np
3
+ import torch
4
+
5
+
6
+ class SMeasure:
7
+ def __init__(self, alpha: float = 0.5):
8
+ self.alpha: float = alpha
9
+ self.cuda: bool = True
10
+
11
+ def _centroid(self, gt):
12
+ rows, cols = gt.size()[-2:]
13
+ gt = gt.view(rows, cols)
14
+ if gt.sum() == 0:
15
+ if self.cuda:
16
+ X = torch.eye(1).cuda() * round(cols / 2)
17
+ Y = torch.eye(1).cuda() * round(rows / 2)
18
+ else:
19
+ X = torch.eye(1) * round(cols / 2)
20
+ Y = torch.eye(1) * round(rows / 2)
21
+ else:
22
+ total = gt.sum()
23
+ if self.cuda:
24
+ i = torch.from_numpy(np.arange(0, cols)).cuda().float()
25
+ j = torch.from_numpy(np.arange(0, rows)).cuda().float()
26
+ else:
27
+ i = torch.from_numpy(np.arange(0, cols)).float()
28
+ j = torch.from_numpy(np.arange(0, rows)).float()
29
+ X = torch.round((gt.sum(dim=0) * i).sum() / total)
30
+ Y = torch.round((gt.sum(dim=1) * j).sum() / total)
31
+ return X.long(), Y.long()
32
+
33
+ def _ssim(self, pred, gt):
34
+ gt = gt.float()
35
+ h, w = pred.size()[-2:]
36
+ N = h * w
37
+ x = pred.mean()
38
+ y = gt.mean()
39
+ sigma_x2 = ((pred - x) * (pred - x)).sum() / (N - 1 + 1e-20)
40
+ sigma_y2 = ((gt - y) * (gt - y)).sum() / (N - 1 + 1e-20)
41
+ sigma_xy = ((pred - x) * (gt - y)).sum() / (N - 1 + 1e-20)
42
+
43
+ alpha = 4 * x * y * sigma_xy
44
+ beta = (x * x + y * y) * (sigma_x2 + sigma_y2)
45
+
46
+ if alpha != 0:
47
+ Q = alpha / (beta + 1e-20)
48
+ elif alpha == 0 and beta == 0:
49
+ Q = 1.0
50
+ else:
51
+ Q = 0
52
+ return Q
53
+
54
+ def _object(self, pred, gt):
55
+ temp = pred[gt == 1]
56
+ x = temp.mean()
57
+ sigma_x = temp.std()
58
+ score = 2.0 * x / (x * x + 1.0 + sigma_x + 1e-20)
59
+
60
+ return score
61
+
62
+ def _s_object(self, pred, gt):
63
+ fg = torch.where(gt == 0, torch.zeros_like(pred), pred)
64
+ bg = torch.where(gt == 1, torch.zeros_like(pred), 1 - pred)
65
+ o_fg = self._object(fg, gt)
66
+ o_bg = self._object(bg, 1 - gt)
67
+ u = gt.mean()
68
+ Q = u * o_fg + (1 - u) * o_bg
69
+ return Q
70
+
71
+ def _divide_gt(self, gt, X, Y):
72
+ h, w = gt.size()[-2:]
73
+ area = h * w
74
+ gt = gt.view(h, w)
75
+ LT = gt[:Y, :X]
76
+ RT = gt[:Y, X:w]
77
+ LB = gt[Y:h, :X]
78
+ RB = gt[Y:h, X:w]
79
+ X = X.float()
80
+ Y = Y.float()
81
+ w1 = X * Y / area
82
+ w2 = (w - X) * Y / area
83
+ w3 = X * (h - Y) / area
84
+ w4 = 1 - w1 - w2 - w3
85
+ return LT, RT, LB, RB, w1, w2, w3, w4
86
+
87
+ def _divide_prediction(self, pred, X, Y):
88
+ h, w = pred.size()[-2:]
89
+ pred = pred.view(h, w)
90
+ LT = pred[:Y, :X]
91
+ RT = pred[:Y, X:w]
92
+ LB = pred[Y:h, :X]
93
+ RB = pred[Y:h, X:w]
94
+ return LT, RT, LB, RB
95
+
96
+ def _s_region(self, pred, gt):
97
+ X, Y = self._centroid(gt)
98
+ gt1, gt2, gt3, gt4, w1, w2, w3, w4 = self._divide_gt(gt, X, Y)
99
+ p1, p2, p3, p4 = self._divide_prediction(pred, X, Y)
100
+ Q1 = self._ssim(p1, gt1)
101
+ Q2 = self._ssim(p2, gt2)
102
+ Q3 = self._ssim(p3, gt3)
103
+ Q4 = self._ssim(p4, gt4)
104
+ Q = w1 * Q1 + w2 * Q2 + w3 * Q3 + w4 * Q4
105
+ # print(Q)
106
+ return Q
107
+
108
+ def __call__(self, pred_mask: torch.Tensor, gt_mask: torch.Tensor):
109
+ assert pred_mask.shape == gt_mask.shape
110
+ y = gt_mask.mean()
111
+ if y == 0:
112
+ x = pred_mask.mean()
113
+ Q = 1.0 - x
114
+ elif y == 1:
115
+ x = pred_mask.mean()
116
+ Q = x
117
+ else:
118
+ gt_mask[gt_mask >= 0.5] = 1
119
+ gt_mask[gt_mask < 0.5] = 0
120
+ # print(self._S_object(pred, gt), self._S_region(pred, gt))
121
+ Q = self.alpha * self._s_object(pred_mask, gt_mask) + (
122
+ 1 - self.alpha
123
+ ) * self._s_region(pred_mask, gt_mask)
124
+ if Q.item() < 0:
125
+ Q = torch.FloatTensor([0.0])
126
+ return Q.item()
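SMeasure combines an object-aware and a region-aware structural similarity, weighted by alpha. Note that the class hard-codes `self.cuda = True`, so in the sketch below (illustration only, not part of this commit) both tensors are placed on the GPU; with CPU-only tensors the centroid computation would hit a device mismatch.

```python
# Minimal usage sketch (illustration only, not part of this commit; assumes a CUDA device).
import torch
from evaluation.metrics.s_measure import SMeasure

pred = torch.rand(64, 64).cuda()                 # soft prediction mask
gt = (torch.rand(64, 64) > 0.5).float().cuda()   # binary ground-truth mask
print(SMeasure(alpha=0.5)(pred, gt))             # scalar score, higher is better
```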
evaluation/saliency.py ADDED
@@ -0,0 +1,272 @@
1
+ # Copyright 2022 - Valeo Comfort and Driving Assistance - valeo.ai
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import numpy as np
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+
20
+ from tqdm import tqdm
21
+ from scipy import ndimage
22
+
23
+ from evaluation.metrics.average_meter import AverageMeter
24
+ from evaluation.metrics.f_measure import FMeasure
25
+ from evaluation.metrics.iou import compute_iou
26
+ from evaluation.metrics.mae import compute_mae
27
+ from evaluation.metrics.pixel_acc import compute_pixel_accuracy
28
+ from evaluation.metrics.s_measure import SMeasure
29
+
30
+ from misc import batch_apply_bilateral_solver
31
+
32
+
33
+ @torch.no_grad()
34
+ def write_metric_tf(writer, metrics, n_iter=-1, name=""):
35
+ writer.add_scalar(
36
+ f"Validation/{name}iou_pred",
37
+ metrics["ious"].avg,
38
+ n_iter,
39
+ )
40
+ writer.add_scalar(
41
+ f"Validation/{name}acc_pred",
42
+ metrics["pixel_accs"].avg,
43
+ n_iter,
44
+ )
45
+ writer.add_scalar(
46
+ f"Validation/{name}f_max",
47
+ metrics["f_maxs"].avg,
48
+ n_iter,
49
+ )
50
+
51
+
52
+ @torch.no_grad()
53
+ def eval_batch(batch_gt_masks, batch_pred_masks, metrics_res={}, reset=False):
54
+ """
55
+ Evaluation code adapted from SelfMask: https://github.com/NoelShin/selfmask
56
+ """
57
+
58
+ f_values = {}
59
+ # Keep track of f_values for each threshold
60
+ for i in range(255): # should equal n_bins in metrics/f_measure.py
61
+ f_values[i] = AverageMeter()
62
+
63
+ if metrics_res == {}:
64
+ metrics_res["f_scores"] = AverageMeter()
65
+ metrics_res["f_maxs"] = AverageMeter()
66
+ metrics_res["f_maxs_fixed"] = AverageMeter()
67
+ metrics_res["f_means"] = AverageMeter()
68
+ metrics_res["maes"] = AverageMeter()
69
+ metrics_res["ious"] = AverageMeter()
70
+ metrics_res["pixel_accs"] = AverageMeter()
71
+ metrics_res["s_measures"] = AverageMeter()
72
+
73
+ if reset:
74
+ metrics_res["f_scores"].reset()
75
+ metrics_res["f_maxs"].reset()
76
+ metrics_res["f_maxs_fixed"].reset()
77
+ metrics_res["f_means"].reset()
78
+ metrics_res["maes"].reset()
79
+ metrics_res["ious"].reset()
80
+ metrics_res["pixel_accs"].reset()
81
+ metrics_res["s_measures"].reset()
82
+
83
+ # iterate over batch dimension
84
+ for _, (pred_mask, gt_mask) in enumerate(zip(batch_pred_masks, batch_gt_masks)):
85
+ assert pred_mask.shape == gt_mask.shape, f"{pred_mask.shape} != {gt_mask.shape}"
86
+ assert len(pred_mask.shape) == len(gt_mask.shape) == 2
87
+ # Compute
88
+ # Binarize at 0.5 for IoU and pixel accuracy
89
+ binary_pred = (pred_mask > 0.5).float().squeeze()
90
+ iou = compute_iou(binary_pred, gt_mask)
91
+ f_measures = FMeasure()(pred_mask, gt_mask) # soft mask for F measure
92
+ mae = compute_mae(binary_pred, gt_mask)
93
+ pixel_acc = compute_pixel_accuracy(binary_pred, gt_mask)
94
+
95
+ # Update
96
+ metrics_res["ious"].update(val=iou.numpy(), n=1)
97
+ metrics_res["f_scores"].update(val=f_measures["f_measure"].numpy(), n=1)
98
+ metrics_res["f_maxs"].update(val=f_measures["f_max"].numpy(), n=1)
99
+ metrics_res["f_means"].update(val=f_measures["f_mean"].numpy(), n=1)
100
+ metrics_res["s_measures"].update(
101
+ val=SMeasure()(pred_mask=pred_mask, gt_mask=gt_mask.to(torch.float32)), n=1
102
+ )
103
+ metrics_res["maes"].update(val=mae.numpy(), n=1)
104
+ metrics_res["pixel_accs"].update(val=pixel_acc.numpy(), n=1)
105
+
106
+ # Keep track of f_values for each threshold
107
+ all_f = f_measures["all_f"].numpy()
108
+ for k, v in f_values.items():
109
+ v.update(val=all_f[k], n=1)
110
+ # Then compute the max for the f_max_fixed
111
+ metrics_res["f_maxs_fixed"].update(
112
+ val=np.max([v.avg for v in f_values.values()]), n=1
113
+ )
114
+
115
+ results = {}
116
+ # F-measure, F-max, F-mean, MAE, S-measure, IoU, pixel acc.
117
+ results["f_measure"] = metrics_res["f_scores"].avg
118
+ results["f_max"] = metrics_res["f_maxs"].avg
119
+ results["f_maxs_fixed"] = metrics_res["f_maxs_fixed"].avg
120
+ results["f_mean"] = metrics_res["f_means"].avg
121
+ results["s_measure"] = metrics_res["s_measures"].avg
122
+ results["mae"] = metrics_res["maes"].avg
123
+ results["iou"] = metrics_res["ious"].avg  # average IoU, consistent with the other reported metrics
124
+ results["pixel_acc"] = metrics_res["pixel_accs"].avg
125
+
126
+ return results, metrics_res
127
+
128
+
129
+ def evaluate_saliency(
130
+ dataset,
131
+ model,
132
+ writer=None,
133
+ batch_size=1,
134
+ n_iter=-1,
135
+ apply_bilateral=False,
136
+ im_fullsize=True,
137
+ method="pred", # can also be "bkg",
138
+ apply_weights: bool = True,
139
+ evaluation_mode: str = "single", # choices are ["single", "multi"]
140
+ ):
141
+
142
+ if im_fullsize:
143
+ # Change transformation
144
+ dataset.fullimg_mode()
145
+ batch_size = 1
146
+
147
+ valloader = torch.utils.data.DataLoader(
148
+ dataset, batch_size=batch_size, shuffle=False, num_workers=2
149
+ )
150
+
151
+ sigmoid = nn.Sigmoid()
152
+
153
+ metrics_res = {}
154
+ metrics_res_bs = {}
155
+ valbar = tqdm(enumerate(valloader, 0), leave=None)
156
+ for i, data in valbar:
157
+ inputs, _, _, _, _, gt_labels, _ = data
158
+ inputs = inputs.to("cuda")
159
+ gt_labels = gt_labels.to("cuda").float()
160
+
161
+ # Forward step
162
+ with torch.no_grad():
163
+ preds = model(inputs, for_eval=True)
164
+
165
+ h, w = gt_labels.shape[-2:]
166
+ preds_up = F.interpolate(
167
+ preds,
168
+ scale_factor=model.vit_patch_size,
169
+ mode="bilinear",
170
+ align_corners=False,
171
+ )[..., :h, :w]
172
+ soft_preds = sigmoid(preds_up.detach()).squeeze(0)
173
+ preds_up = (sigmoid(preds_up.detach()) > 0.5).squeeze(0).float()
174
+
175
+ reset = True if i == 0 else False
176
+ if evaluation_mode == "single":
177
+ labeled, nr_objects = ndimage.label(preds_up.squeeze().cpu().numpy())
178
+ if nr_objects == 0:
179
+ preds_up_one_cc = preds_up.squeeze()
180
+ print("nr_objects == 0")
181
+ else:
182
+ nb_pixel = [np.sum(labeled == i) for i in range(nr_objects + 1)]
183
+ pixel_order = np.argsort(nb_pixel)
184
+
185
+ cc = [torch.Tensor(labeled == i) for i in pixel_order]
186
+ cc = torch.stack(cc).cuda()
187
+
188
+ # Find CC set as background, here not necessarily the biggest
189
+ cc_background = (
190
+ (
191
+ (
192
+ (~(preds_up[None, :, :, :].bool())).float()
193
+ + cc[:, None, :, :].cuda()
194
+ )
195
+ > 1
196
+ )
197
+ .sum(-1)
198
+ .sum(-1)
199
+ .argmax()
200
+ )
201
+ pixel_order = np.delete(pixel_order, int(cc_background.cpu().numpy()))
202
+
203
+ preds_up_one_cc = torch.Tensor(labeled == pixel_order[-1]).cuda()
204
+
205
+ _, metrics_res = eval_batch(
206
+ gt_labels,
207
+ preds_up_one_cc.unsqueeze(0),
208
+ metrics_res=metrics_res,
209
+ reset=reset,
210
+ )
211
+
212
+ elif evaluation_mode == "multi":
213
+ # Eval without bilateral solver
214
+ _, metrics_res = eval_batch(
215
+ gt_labels,
216
+ soft_preds.unsqueeze(0) if len(soft_preds.shape) == 2 else soft_preds,
217
+ metrics_res=metrics_res,
218
+ reset=reset,
219
+ ) # soft preds needed for F beta measure
220
+
221
+ # Apply bilateral solver
222
+ preds_bs = None
223
+ if apply_bilateral:
224
+ get_all_cc = True if evaluation_mode == "multi" else False
225
+ preds_bs, _ = batch_apply_bilateral_solver(
226
+ data, preds_up.detach(), get_all_cc=get_all_cc
227
+ )
228
+
229
+ _, metrics_res_bs = eval_batch(
230
+ gt_labels,
231
+ preds_bs[None, :, :].float(),
232
+ metrics_res=metrics_res_bs,
233
+ reset=reset,
234
+ )
235
+
236
+ bar_str = (
237
+ f"{dataset.name} | {evaluation_mode} mode | "
238
+ f"F-max {metrics_res['f_maxs'].avg:.3f} "
239
+ f"IoU {metrics_res['ious'].avg:.3f}, "
240
+ f"PA {metrics_res['pixel_accs'].avg:.3f}"
241
+ )
242
+
243
+ if apply_bilateral:
244
+ bar_str += (
245
+ f" | with bilateral solver: "
246
+ f"F-max {metrics_res_bs['f_maxs'].avg:.3f}, "
247
+ f"IoU {metrics_res_bs['ious'].avg:.3f}, "
248
+ f"PA. {metrics_res_bs['pixel_accs'].avg:.3f}"
249
+ )
250
+
251
+ valbar.set_description(bar_str)
252
+
253
+ # Writing in tensorboard
254
+ if writer is not None:
255
+ write_metric_tf(
256
+ writer,
257
+ metrics_res,
258
+ n_iter=n_iter,
259
+ name=f"{dataset.name}_{evaluation_mode}_",
260
+ )
261
+
262
+ if apply_bilateral:
263
+ write_metric_tf(
264
+ writer,
265
+ metrics_res_bs,
266
+ n_iter=n_iter,
267
+ name=f"{dataset.name}_{evaluation_mode}-BS_",
268
+ )
269
+
270
+ # Go back to original transformation
271
+ if im_fullsize:
272
+ dataset.training_mode()
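For context, here is a small, hedged sketch of how `eval_batch` accumulates metrics across batches. The tensors are random and only illustrate the expected shapes; they are not results from the paper.

```python
# Hedged toy example for eval_batch above; masks are random and only illustrate shapes.
import torch
from evaluation.saliency import eval_batch

torch.manual_seed(0)
gt_masks = (torch.rand(2, 64, 64) > 0.5).float()   # batch of binary ground-truth masks
pred_masks = torch.rand(2, 64, 64)                 # batch of soft predictions in [0, 1]

# First batch: pass a fresh dict so the AverageMeters are created.
results, metrics_res = eval_batch(gt_masks, pred_masks, metrics_res={}, reset=True)
print({k: float(v) for k, v in results.items()})

# Later batches reuse the same metrics_res so the averages accumulate.
results, metrics_res = eval_batch(gt_masks, pred_masks, metrics_res=metrics_res)
```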
evaluation/uod.py ADDED
@@ -0,0 +1,117 @@
1
+ # Copyright 2021 - Valeo Comfort and Driving Assistance - Oriane Siméoni @ valeo.ai
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Code adapted from previous method LOST: https://github.com/valeoai/LOST
17
+ """
18
+
19
+ import os
20
+ import time
21
+ import torch
22
+ import torch.nn as nn
23
+ import numpy as np
24
+
25
+ from tqdm import tqdm
26
+ from misc import bbox_iou, get_bbox_from_segmentation_labels
27
+
28
+
29
+ def evaluation_unsupervised_object_discovery(
30
+ dataset,
31
+ model,
32
+ evaluation_mode: str = "single", # choices are ["single", "multi"]
33
+ output_dir: str = "outputs",
34
+ no_hards: bool = False,
35
+ ):
36
+
37
+ assert evaluation_mode == "single"
38
+
39
+ sigmoid = nn.Sigmoid()
40
+
41
+ # ----------------------------------------------------
42
+ # Loop over images
43
+ preds_dict = {}
44
+ cnt = 0
45
+ corloc = np.zeros(len(dataset.dataloader))
46
+
47
+ start_time = time.time()
48
+ pbar = tqdm(dataset.dataloader)
49
+ for im_id, inp in enumerate(pbar):
50
+
51
+ # ------------ IMAGE PROCESSING -------------------------------------------
52
+ img = inp[0]
53
+
54
+ init_image_size = img.shape
55
+
56
+ # Get the name of the image
57
+ im_name = dataset.get_image_name(inp[1])
58
+ # Pass in case of no gt boxes in the image
59
+ if im_name is None:
60
+ continue
61
+
62
+ # Padding the image with zeros to fit multiple of patch-size
63
+ size_im = (
64
+ img.shape[0],
65
+ int(np.ceil(img.shape[1] / model.vit_patch_size) * model.vit_patch_size),
66
+ int(np.ceil(img.shape[2] / model.vit_patch_size) * model.vit_patch_size),
67
+ )
68
+ padded = torch.zeros(size_im)
69
+ padded[:, : img.shape[1], : img.shape[2]] = img
70
+ img = padded
71
+
72
+ # # Move to gpu
73
+ img = img.cuda(non_blocking=True)
74
+
75
+ # Size for transformers
76
+ # w_featmap = img.shape[-2] // model.vit_patch_size
77
+ # h_featmap = img.shape[-1] // model.vit_patch_size
78
+
79
+ # ------------ GROUND-TRUTH -------------------------------------------
80
+ gt_bbxs, gt_cls = dataset.extract_gt(inp[1], im_name)
81
+
82
+ if gt_bbxs is not None:
83
+ # Discard images with no gt annotations
84
+ # Happens only in the case of VOC07 and VOC12
85
+ if gt_bbxs.shape[0] == 0 and no_hards:
86
+ continue
87
+
88
+ outputs = model(img[None, :, :, :])
89
+ preds = (sigmoid(outputs[0].detach()) > 0.5).float().squeeze().cpu().numpy()
90
+
91
+ # get bbox
92
+ pred = get_bbox_from_segmentation_labels(
93
+ segmenter_predictions=preds,
94
+ scales=[model.vit_patch_size, model.vit_patch_size],
95
+ initial_image_size=init_image_size[1:],
96
+ )
97
+
98
+ # ------------ Visualizations -------------------------------------------
99
+ # Save the prediction
100
+ preds_dict[im_name] = pred
101
+
102
+ # Compare prediction to GT boxes
103
+ ious = bbox_iou(torch.from_numpy(pred), torch.from_numpy(gt_bbxs))
104
+
105
+ if torch.any(ious >= 0.5):
106
+ corloc[im_id] = 1
107
+
108
+ cnt += 1
109
+ if cnt % 50 == 0:
110
+ pbar.set_description(f"Peekaboo {int(np.sum(corloc))}/{cnt}")
111
+
112
+ # Evaluate
113
+ print(f"corloc: {100*np.sum(corloc)/cnt:.2f} ({int(np.sum(corloc))}/{cnt})")
114
+ result_file = os.path.join(output_dir, "uod_results.txt")
115
+ with open(result_file, "w") as f:
116
+ f.write("corloc,%.1f,,\n" % (100 * np.sum(corloc) / cnt))
117
+ print("File saved at %s" % result_file)
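The CorLoc bookkeeping above reduces to one decision per image: the predicted box counts as correctly localized if it reaches IoU ≥ 0.5 with any ground-truth box. A tiny illustration of that decision with made-up boxes:

```python
# Toy illustration of the CorLoc decision used above; box coordinates are made up.
import torch
from misc import bbox_iou

pred_box = torch.tensor([48.0, 30.0, 210.0, 180.0])    # xmin, ymin, xmax, ymax
gt_boxes = torch.tensor([[50.0, 32.0, 200.0, 175.0],
                         [300.0, 10.0, 340.0, 60.0]])  # one row per ground-truth box

ious = bbox_iou(pred_box, gt_boxes)                     # IoU against every gt box
correctly_localized = bool(torch.any(ious >= 0.5))      # contributes 1 to CorLoc if True
print(ious, correctly_localized)
```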
format_codebase.sh ADDED
@@ -0,0 +1,16 @@
1
+ #!/bin/sh
2
+
3
+ # Script to format codebase
4
+
5
+ # pip install autopep8
6
+ # pip install --force-reinstall --upgrade typed-ast black
7
+
8
+ # Run autopep8 to fix specific PEP 8 issues
9
+ autopep8 --in-place --recursive --select=E1,E2,E3,W1,W2 ./**.py
10
+
11
+ # Run black to enforce consistent formatting
12
+ black ./
13
+
14
+ # To run this file
15
+ # chmod +x format_codebase.sh
16
+ # ./format_codebase.sh
media/description.html ADDED
@@ -0,0 +1,22 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <title>Title</title>
6
+ </head>
7
+ <body>
8
+ Try this demo for <a href="https://github.com/hasibzunair/peekaboo">PEEKABOO</a>,
9
+ introduced in our <strong>BMVC'2024</strong> paper <a href="https://arxiv.org/abs/2407.17628">PEEKABOO: Hiding Parts of an Image for Unsupervised Object Localization</a>.
10
+ <br>
11
+ Peekaboo aims to explicitly model contextual relationship among pixels through image masking for unsupervised object localization.
12
+ In a self-supervised procedure (i.e. pretext task) without any additional training (i.e. downstream task), context-based representation learning is done at both
13
+ the pixel-level by making predictions on masked images and at shape-level by matching the predictions of the masked input to the unmasked one.
14
+ <br>
15
+ You can use this demo to segment the most salient object(s) in your images. To use it, simply
16
+ upload an image of your choice and hit submit. You will get one or more segmentation maps of the most salient objects present
17
+ in your images.
18
+ <br>
19
+ <a href="https://hasibzunair.github.io/peekaboo/"><strong>Project Page</strong></a>
20
+ <br>
21
+ </body>
22
+ </html>
misc.py ADDED
@@ -0,0 +1,337 @@
1
+ # Code for Peekaboo
2
+ # Author: Hasib Zunair
3
+ # Modified from https://github.com/valeoai/FOUND, see license below.
4
+
5
+ # Copyright 2022 - Valeo Comfort and Driving Assistance - Oriane Siméoni @ valeo.ai
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+ """Helper functions"""
20
+
21
+ import re
22
+ import os
23
+ import cv2
24
+ import sys
25
+ import os.path as osp
26
+ import errno
27
+ import yaml
28
+ import math
29
+ import random
30
+ import scipy.ndimage
31
+ import numpy as np
32
+
33
+ import torch
34
+ import torch.nn.functional as F
35
+
36
+ from typing import List
37
+ from torchvision import transforms as T
38
+
39
+ from bilateral_solver import bilateral_solver_output
40
+
41
+
42
+ loader = yaml.SafeLoader
43
+ loader.add_implicit_resolver(
44
+ "tag:yaml.org,2002:float",
45
+ re.compile(
46
+ """^(?:
47
+ [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
48
+ |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
49
+ |\\.[0-9_]+(?:[eE][-+][0-9]+)?
50
+ |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
51
+ |[-+]?\\.(?:inf|Inf|INF)
52
+ |\\.(?:nan|NaN|NAN))$""",
53
+ re.X,
54
+ ),
55
+ list("-+0123456789."),
56
+ )
57
+
58
+
59
+ def mkdir_if_missing(directory):
60
+ if not osp.exists(directory):
61
+ try:
62
+ os.makedirs(directory)
63
+ except OSError as e:
64
+ if e.errno != errno.EEXIST:
65
+ raise
66
+
67
+
68
+ class Logger(object):
69
+ """
70
+ Write console output to external text file.
71
+ Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.
72
+ """
73
+
74
+ def __init__(self, fpath=None):
75
+ self.console = sys.stdout
76
+ self.file = None
77
+ if fpath is not None:
78
+ mkdir_if_missing(os.path.dirname(fpath))
79
+ self.file = open(fpath, "w")
80
+
81
+ def __del__(self):
82
+ self.close()
83
+
84
+ def __enter__(self):
85
+ pass
86
+
87
+ def __exit__(self, *args):
88
+ self.close()
89
+
90
+ def write(self, msg):
91
+ self.console.write(msg)
92
+ if self.file is not None:
93
+ self.file.write(msg)
94
+
95
+ def flush(self):
96
+ self.console.flush()
97
+ if self.file is not None:
98
+ self.file.flush()
99
+ os.fsync(self.file.fileno())
100
+
101
+ def close(self):
102
+ self.console.close()
103
+ if self.file is not None:
104
+ self.file.close()
105
+
106
+
107
+ class Struct:
108
+ def __init__(self, **entries):
109
+ self.__dict__.update(entries)
110
+
111
+
112
+ def load_config(config_file):
113
+ with open(config_file, errors="ignore") as f:
114
+ # conf = yaml.safe_load(f) # load config
115
+ conf = yaml.load(f, Loader=loader)
116
+ print("hyperparameters: " + ", ".join(f"{k}={v}" for k, v in conf.items()))
117
+
118
+ # TODO yaml_save(save_dir / 'config.yaml', conf)
119
+ return Struct(**conf), conf # conf returned to print it
120
+
121
+
122
+ def set_seed(seed: int) -> None:
123
+ """
124
+ Set all seeds to make results reproducible
125
+ """
126
+ # env
127
+ os.environ["PYTHONHASHSEED"] = str(seed)
128
+
129
+ # python
130
+ random.seed(seed)
131
+
132
+ # numpy
133
+ np.random.seed(seed)
134
+
135
+ # torch
136
+ torch.manual_seed(seed)
137
+ torch.cuda.manual_seed(seed)
138
+ torch.cuda.manual_seed_all(seed)
139
+ if torch.cuda.is_available():
140
+ torch.backends.cudnn.deterministic = True
141
+ torch.backends.cudnn.benchmark = False  # disabled for reproducibility
142
+
143
+
144
+ def IoU(mask1, mask2):
145
+ """
146
+ Code adapted from TokenCut: https://github.com/YangtaoWANG95/TokenCut
147
+ """
148
+ mask1, mask2 = (mask1 > 0.5).to(torch.bool), (mask2 > 0.5).to(torch.bool)
149
+ intersection = torch.sum(mask1 * (mask1 == mask2), dim=[-1, -2]).squeeze()
150
+ union = torch.sum(mask1 + mask2, dim=[-1, -2]).squeeze()
151
+ return (intersection.to(torch.float) / union).mean().item()
152
+
153
+
154
+ def batch_apply_bilateral_solver(data, masks, get_all_cc=True, shape=None):
155
+
156
+ cnt_bs = 0
157
+ masks_bs = []
158
+
159
+ # inputs, init_imgs, gt_labels, img_path = data
160
+ inputs, _, _, init_imgs, _, gt_labels, img_path = data
161
+
162
+ for id in range(inputs.shape[0]):
163
+ _, bs_mask, use_bs = apply_bilateral_solver(
164
+ mask=masks[id].squeeze().cpu().numpy(),
165
+ img=init_imgs[id],
166
+ img_path=img_path[id],
167
+ im_fullsize=False,
168
+ # Careful: shape is passed reversed, i.e. (width, height)
169
+ shape=(gt_labels.shape[-1], gt_labels.shape[-2]),
170
+ get_all_cc=get_all_cc,
171
+ )
172
+ cnt_bs += use_bs
173
+
174
+ # use the bilateral solver output if IoU > 0.5
175
+ if use_bs:
176
+ if shape is None:
177
+ shape = masks.shape[-2:]
178
+ # Interpolate to downsample the mask back
179
+ bs_ds = F.interpolate(
180
+ torch.Tensor(bs_mask).unsqueeze(0).unsqueeze(0),
181
+ shape, # TODO check here
182
+ mode="bilinear",
183
+ align_corners=False,
184
+ )
185
+ masks_bs.append(bs_ds.bool().cuda().squeeze()[None, :, :])
186
+ else:
187
+ # Use initial mask
188
+ masks_bs.append(masks[id].cuda().squeeze()[None, :, :])
189
+
190
+ return torch.cat(masks_bs).squeeze(), cnt_bs
191
+
192
+
193
+ def apply_bilateral_solver(
194
+ mask,
195
+ img,
196
+ img_path,
197
+ shape,
198
+ im_fullsize=False,
199
+ get_all_cc=False,
200
+ bs_iou_threshold: float = 0.5,
201
+ reshape: bool = True,
202
+ ):
203
+ # Get initial image in the case of using full image
204
+ img_init = None
205
+ if not im_fullsize:
206
+ # Use the image given by dataloader
207
+ shape = (img.shape[-1], img.shape[-2])
208
+ t = T.ToPILImage()
209
+ img_init = t(img)
210
+
211
+ if reshape:
212
+ # Resize predictions to image size
213
+ resized_mask = cv2.resize(mask, shape)
214
+ sel_obj_mask = resized_mask
215
+ else:
216
+ resized_mask = mask
217
+ sel_obj_mask = mask
218
+
219
+ # Apply bilateral solver
220
+ _, binary_solver = bilateral_solver_output(
221
+ img_path,
222
+ resized_mask,
223
+ img=img_init,
224
+ sigma_spatial=16,
225
+ sigma_luma=16,
226
+ sigma_chroma=8,
227
+ get_all_cc=get_all_cc,
228
+ )
229
+
230
+ mask1 = torch.from_numpy(resized_mask).cuda()
231
+ mask2 = torch.from_numpy(binary_solver).cuda().float()
232
+
233
+ use_bs = 0
234
+ # If enough overlap, use BS output
235
+ if IoU(mask1, mask2) > bs_iou_threshold:
236
+ sel_obj_mask = binary_solver.astype(float)
237
+ use_bs = 1
238
+
239
+ return resized_mask, sel_obj_mask, use_bs
240
+
241
+
242
+ def get_bbox_from_segmentation_labels(
243
+ segmenter_predictions: torch.Tensor,
244
+ initial_image_size: torch.Size,
245
+ scales: List[int],
246
+ ) -> np.array:
247
+ """
248
+ Find the largest connected component in foreground, extract its bounding box
249
+ """
250
+ objects, num_objects = scipy.ndimage.label(segmenter_predictions)
251
+
252
+ # find biggest connected component
253
+ all_foreground_labels = objects.flatten()[objects.flatten() != 0]
254
+ most_frequent_label = np.bincount(all_foreground_labels).argmax()
255
+ mask = np.where(objects == most_frequent_label)
256
+ # Add +1 because excluded max
257
+ ymin, ymax = min(mask[0]), max(mask[0]) + 1
258
+ xmin, xmax = min(mask[1]), max(mask[1]) + 1
259
+
260
+ if initial_image_size == segmenter_predictions.shape:
261
+ # Masks are already upsampled
262
+ pred = [xmin, ymin, xmax, ymax]
263
+ else:
264
+ # Rescale to image size
265
+ r_xmin, r_xmax = scales[1] * xmin, scales[1] * xmax
266
+ r_ymin, r_ymax = scales[0] * ymin, scales[0] * ymax
267
+ pred = [r_xmin, r_ymin, r_xmax, r_ymax]
268
+
269
+ # Check not out of image size (used when padding)
270
+ if initial_image_size:
271
+ pred[2] = min(pred[2], initial_image_size[1])
272
+ pred[3] = min(pred[3], initial_image_size[0])
273
+
274
+ return np.asarray(pred)
275
+
276
+
277
+ def bbox_iou(
278
+ box1: np.array,
279
+ box2: np.array,
280
+ x1y1x2y2: bool = True,
281
+ GIoU: bool = False,
282
+ DIoU: bool = False,
283
+ CIoU: bool = False,
284
+ eps: float = 1e-7,
285
+ ):
286
+ # https://github.com/ultralytics/yolov5/blob/develop/utils/general.py
287
+ # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
288
+ box2 = box2.T
289
+
290
+ # Get the coordinates of bounding boxes
291
+ if x1y1x2y2: # x1, y1, x2, y2 = box1
292
+ b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
293
+ b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
294
+ else: # transform from xywh to xyxy
295
+ b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
296
+ b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
297
+ b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
298
+ b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
299
+
300
+ # Intersection area
301
+ inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * (
302
+ torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)
303
+ ).clamp(0)
304
+
305
+ # Union Area
306
+ w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
307
+ w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
308
+ union = w1 * h1 + w2 * h2 - inter + eps
309
+
310
+ iou = inter / union
311
+ if GIoU or DIoU or CIoU:
312
+ cw = torch.max(b1_x2, b2_x2) - torch.min(
313
+ b1_x1, b2_x1
314
+ ) # convex (smallest enclosing box) width
315
+ ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
316
+ if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
317
+ c2 = cw**2 + ch**2 + eps # convex diagonal squared
318
+ rho2 = (
319
+ (b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2
320
+ + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2
321
+ ) / 4 # center distance squared
322
+ if DIoU:
323
+ return iou - rho2 / c2 # DIoU
324
+ elif (
325
+ CIoU
326
+ ): # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
327
+ v = (4 / math.pi**2) * torch.pow(
328
+ torch.atan(w2 / h2) - torch.atan(w1 / h1), 2
329
+ )
330
+ with torch.no_grad():
331
+ alpha = v / (v - iou + (1 + eps))
332
+ return iou - (rho2 / c2 + v * alpha) # CIoU
333
+ else: # GIoU https://arxiv.org/pdf/1902.09630.pdf
334
+ c_area = cw * ch + eps # convex area
335
+ return iou - (c_area - union) / c_area # GIoU
336
+ else:
337
+ return iou # IoU
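To make the box helpers above concrete, here is a hedged toy check: a rectangular binary mask is converted to a box with `get_bbox_from_segmentation_labels`, then compared against itself with `bbox_iou`, which should come out at essentially 1.0. The mask, scales, and image size are illustrative, not values from the paper.

```python
# Toy check of get_bbox_from_segmentation_labels and bbox_iou; values are illustrative.
import numpy as np
import torch
from misc import get_bbox_from_segmentation_labels, bbox_iou

mask = np.zeros((28, 28), dtype=np.uint8)
mask[5:15, 8:20] = 1                          # one rectangular "object" at patch resolution

box = get_bbox_from_segmentation_labels(
    segmenter_predictions=mask,
    initial_image_size=(224, 224),            # original image size (H, W)
    scales=[8, 8],                            # e.g. the ViT patch size along each axis
)
print(box)                                    # [xmin, ymin, xmax, ymax] in image coordinates

iou = bbox_iou(torch.from_numpy(box).float(), torch.from_numpy(box[None, :]).float())
print(float(iou))                             # ~1.0 for a box compared with itself
```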
model.py ADDED
@@ -0,0 +1,180 @@
1
+ # Code for Peekaboo
2
+ # Author: Hasib Zunair
3
+ # Modified from https://github.com/valeoai/FOUND, see license below.
4
+
5
+ # Copyright 2022 - Valeo Comfort and Driving Assistance - Oriane Siméoni @ valeo.ai
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+ """Model code for Peekaboo"""
20
+
21
+ import os
22
+ import torch
23
+ import torch.nn as nn
24
+ import dino.vision_transformer as vits
25
+
26
+
27
+ class PeekabooModel(nn.Module):
28
+ def __init__(
29
+ self,
30
+ vit_model="dino",
31
+ vit_arch="vit_small",
32
+ vit_patch_size=8,
33
+ enc_type_feats="k",
34
+ ):
35
+
36
+ super(PeekabooModel, self).__init__()
37
+
38
+ ########## Encoder ##########
39
+ self.vit_encoder, self.initial_dim, self.hook_features = get_vit_encoder(
40
+ vit_arch, vit_model, vit_patch_size, enc_type_feats
41
+ )
42
+ self.vit_patch_size = vit_patch_size
43
+ self.enc_type_feats = enc_type_feats
44
+
45
+ ########## Decoder ##########
46
+ self.previous_dim = self.initial_dim
47
+ self.decoder = nn.Conv2d(self.previous_dim, 1, (1, 1))
48
+
49
+ def _make_input_divisible(self, x: torch.Tensor) -> torch.Tensor:
50
+ # From selfmask
51
+ """Pad some pixels to make the input size divisible by the patch size."""
52
+ B, _, H_0, W_0 = x.shape
53
+ pad_w = (self.vit_patch_size - W_0 % self.vit_patch_size) % self.vit_patch_size
54
+ pad_h = (self.vit_patch_size - H_0 % self.vit_patch_size) % self.vit_patch_size
55
+
56
+ x = nn.functional.pad(x, (0, pad_w, 0, pad_h), value=0)
57
+ return x
58
+
59
+ def forward(self, batch, decoder=None, for_eval=False):
60
+
61
+ # Make the image divisible by the patch size
62
+ if for_eval:
63
+ batch = self._make_input_divisible(batch)
64
+ _w, _h = batch.shape[-2:]
65
+ _h, _w = _h // self.vit_patch_size, _w // self.vit_patch_size
66
+ else:
67
+ # Cropping used during training, could be changed to improve
68
+ w, h = (
69
+ batch.shape[-2] - batch.shape[-2] % self.vit_patch_size,
70
+ batch.shape[-1] - batch.shape[-1] % self.vit_patch_size,
71
+ )
72
+ batch = batch[:, :, :w, :h]
73
+
74
+ w_featmap = batch.shape[-2] // self.vit_patch_size
75
+ h_featmap = batch.shape[-1] // self.vit_patch_size
76
+
77
+ # Forward pass
78
+ with torch.no_grad():
79
+ # Encoder forward pass
80
+ att = self.vit_encoder.get_last_selfattention(batch)
81
+
82
+ # Get decoder features
83
+ feats = self.extract_feats(dims=att.shape, type_feats=self.enc_type_feats)
84
+ feats = feats[:, 1:, :, :].reshape(att.shape[0], w_featmap, h_featmap, -1)
85
+ feats = feats.permute(0, 3, 1, 2)
86
+
87
+ # Apply decoder
88
+ if decoder is None:
89
+ decoder = self.decoder
90
+
91
+ logits = decoder(feats)
92
+ return logits
93
+
94
+ @torch.no_grad()
95
+ def decoder_load_weights(self, weights_path):
96
+ print(f"Loading model from weights {weights_path}.")
97
+ # Load states
98
+ if torch.cuda.is_available():
99
+ state_dict = torch.load(weights_path)
100
+ else:
101
+ state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
102
+
103
+ # Decoder
104
+ self.decoder.load_state_dict(state_dict["decoder"])
105
+ self.decoder.eval()
106
+ self.decoder.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
107
+
108
+ @torch.no_grad()
109
+ def decoder_save_weights(self, save_dir, n_iter):
110
+ state_dict = {}
111
+ state_dict["decoder"] = self.decoder.state_dict()
112
+ fname = os.path.join(save_dir, f"decoder_weights_niter{n_iter}.pt")
113
+ torch.save(state_dict, fname)
114
+ print(f"\n----" f"\nModel saved at {fname}")
115
+
116
+ @torch.no_grad()
117
+ def extract_feats(self, dims, type_feats="k"):
118
+
119
+ nb_im, nh, nb_tokens, _ = dims
120
+ qkv = (
121
+ self.hook_features["qkv"]
122
+ .reshape(nb_im, nb_tokens, 3, nh, -1 // nh) # 3 corresponding to |qkv|
123
+ .permute(2, 0, 3, 1, 4)
124
+ )
125
+
126
+ q, k, v = qkv[0], qkv[1], qkv[2]
127
+
128
+ if type_feats == "q":
129
+ return q.transpose(1, 2).float()
130
+ elif type_feats == "k":
131
+ return k.transpose(1, 2).float()
132
+ elif type_feats == "v":
133
+ return v.transpose(1, 2).float()
134
+ else:
135
+ raise ValueError("Unknown features")
136
+
137
+
138
+ def get_vit_encoder(vit_arch, vit_model, vit_patch_size, enc_type_feats):
139
+ if vit_arch == "vit_small" and vit_patch_size == 16:
140
+ url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
141
+ initial_dim = 384
142
+ elif vit_arch == "vit_small" and vit_patch_size == 8:
143
+ url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"
144
+ initial_dim = 384
145
+ elif vit_arch == "vit_base" and vit_patch_size == 16:
146
+ if vit_model == "clip":
147
+ url = "5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"
148
+ elif vit_model == "dino":
149
+ url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
150
+ initial_dim = 768
151
+ elif vit_arch == "vit_base" and vit_patch_size == 8:
152
+ url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
153
+ initial_dim = 768
154
+
155
+ if vit_model == "dino":
156
+ vit_encoder = vits.__dict__[vit_arch](patch_size=vit_patch_size, num_classes=0)
157
+ # TODO: change this if the last layer should be left unfrozen
158
+ for p in vit_encoder.parameters():
159
+ p.requires_grad = False
160
+ vit_encoder.eval().to(
161
+ torch.device("cuda" if torch.cuda.is_available() else "cpu")
162
+ ) # mode eval
163
+ state_dict = torch.hub.load_state_dict_from_url(
164
+ url="https://dl.fbaipublicfiles.com/dino/" + url
165
+ )
166
+ vit_encoder.load_state_dict(state_dict, strict=True)
167
+
168
+ hook_features = {}
169
+ if enc_type_feats in ["k", "q", "v", "qkv", "mlp"]:
170
+ # Define the hook
171
+ def hook_fn_forward_qkv(module, input, output):
172
+ hook_features["qkv"] = output
173
+
174
+ vit_encoder._modules["blocks"][-1]._modules["attn"]._modules[
175
+ "qkv"
176
+ ].register_forward_hook(hook_fn_forward_qkv)
177
+ else:
178
+ raise ValueError("Not implemented.")
179
+
180
+ return vit_encoder, initial_dim, hook_features
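Putting the pieces of `model.py` together, below is a hedged single-image inference sketch that mirrors the forward / upsample / threshold steps used in `evaluation/saliency.py`. The checkpoint path and image path are placeholders, not files shipped in this commit.

```python
# Hedged single-image inference sketch for PeekabooModel; paths are placeholders.
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms as T
from model import PeekabooModel

device = "cuda" if torch.cuda.is_available() else "cpu"
model = PeekabooModel(vit_model="dino", vit_arch="vit_small", vit_patch_size=8)
model.decoder_load_weights("data/weights/decoder_weights_niter500.pt")  # example path

transform = T.Compose([
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
img = Image.open("my_image.jpg").convert("RGB")   # placeholder input image
inputs = transform(img).unsqueeze(0).to(device)

with torch.no_grad():
    preds = model(inputs, for_eval=True)          # logits at patch resolution
    preds_up = F.interpolate(
        preds, scale_factor=model.vit_patch_size, mode="bilinear", align_corners=False
    )[..., : inputs.shape[-2], : inputs.shape[-1]]
    mask = (torch.sigmoid(preds_up) > 0.5).squeeze().cpu().numpy()

print(mask.shape, mask.mean())                    # binary saliency mask at image resolution
```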
notebooks/exp.ipynb ADDED
@@ -0,0 +1,434 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "print(\"hello\")"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "import os\n",
19
+ "import torch\n",
20
+ "import argparse\n",
21
+ "import torch.nn as nn\n",
22
+ "import torch.nn.functional as F\n",
23
+ "import matplotlib.pyplot as plt\n",
24
+ "\n",
25
+ "os.chdir(\"..\")\n",
26
+ "\n",
27
+ "from PIL import Image\n",
28
+ "from model import FoundModel\n",
29
+ "from misc import load_config\n",
30
+ "from torchvision import transforms as T\n",
31
+ "\n",
32
+ "NORMALIZE = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": null,
38
+ "metadata": {},
39
+ "outputs": [],
40
+ "source": [
41
+ "PATH_TO_IMG = \"./notebooks/0409.jpg\"\n",
42
+ "GT = \"./notebooks/0409.png\"\n",
43
+ "SCRIBBLE = \"./notebooks/11965.png\""
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "execution_count": null,
49
+ "metadata": {},
50
+ "outputs": [],
51
+ "source": [
52
+ "img = Image.open(PATH_TO_IMG)\n",
53
+ "img = img.convert(\"RGB\")\n",
54
+ "img"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": null,
60
+ "metadata": {},
61
+ "outputs": [],
62
+ "source": [
63
+ "scr = Image.open(GT)\n",
64
+ "scr = scr.convert(\"P\")\n",
65
+ "scr"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "code",
70
+ "execution_count": null,
71
+ "metadata": {},
72
+ "outputs": [],
73
+ "source": [
74
+ "try:\n",
75
+ " from torchvision.transforms import InterpolationMode\n",
76
+ "\n",
77
+ " BICUBIC = InterpolationMode.BICUBIC\n",
78
+ "except ImportError:\n",
79
+ " BICUBIC = Image.BICUBIC\n",
80
+ " \n",
81
+ "def _preprocess(img, img_size):\n",
82
+ " transform = T.Compose(\n",
83
+ " [\n",
84
+ " T.Resize(img_size, BICUBIC),\n",
85
+ " T.CenterCrop(img_size),\n",
86
+ " T.ToTensor(),\n",
87
+ " NORMALIZE\n",
88
+ " ]\n",
89
+ " )\n",
90
+ " return transform(img)"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": null,
96
+ "metadata": {},
97
+ "outputs": [],
98
+ "source": [
99
+ "img_t = _preprocess(img, 224)#[None,:,:,:]\n",
100
+ "inputs = img_t.to(\"cuda\")\n",
101
+ "inputs.shape"
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "execution_count": null,
107
+ "metadata": {},
108
+ "outputs": [],
109
+ "source": [
110
+ "scribble = scribble.to(\"cuda\")\n",
111
+ "scribble.shape"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "execution_count": null,
117
+ "metadata": {},
118
+ "outputs": [],
119
+ "source": [
120
+ "m_i = inputs * scribble\n",
121
+ "m_i = m_i[None,:,:,:]\n",
122
+ "inputs = m_i.to(\"cuda\")\n",
123
+ "inputs.shape"
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "code",
128
+ "execution_count": null,
129
+ "metadata": {},
130
+ "outputs": [],
131
+ "source": [
132
+ "from datasets.utils import unnormalize\n",
133
+ "img_init = unnormalize(m_i)\n",
134
+ "img_init.shape"
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "code",
139
+ "execution_count": null,
140
+ "metadata": {},
141
+ "outputs": [],
142
+ "source": [
143
+ "import cv2\n",
144
+ "import numpy as np \n",
145
+ "\n",
146
+ "ten =(img_init.permute(1,2,0).detach().cpu().numpy())\n",
147
+ "ten=(ten*255).astype(np.uint8)\n",
148
+ "#ten=cv2.cvtColor(ten,cv2.COLOR_RGB2BGR)\n",
149
+ "ten.shape"
150
+ ]
151
+ },
152
+ {
153
+ "cell_type": "code",
154
+ "execution_count": null,
155
+ "metadata": {},
156
+ "outputs": [],
157
+ "source": [
158
+ "plt.imshow(ten)\n",
159
+ "plt.axis('off')\n",
160
+ "plt.savefig('masked_image.png', bbox_inches='tight', pad_inches=0)"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "code",
165
+ "execution_count": null,
166
+ "metadata": {},
167
+ "outputs": [],
168
+ "source": []
169
+ },
170
+ {
171
+ "cell_type": "code",
172
+ "execution_count": null,
173
+ "metadata": {},
174
+ "outputs": [],
175
+ "source": [
176
+ "gt = Image.open(GT)\n",
177
+ "gt = gt.convert(\"P\")\n",
178
+ "gt"
179
+ ]
180
+ },
181
+ {
182
+ "cell_type": "code",
183
+ "execution_count": null,
184
+ "metadata": {},
185
+ "outputs": [],
186
+ "source": []
187
+ },
188
+ {
189
+ "cell_type": "code",
190
+ "execution_count": null,
191
+ "metadata": {},
192
+ "outputs": [],
193
+ "source": [
194
+ "try:\n",
195
+ " from torchvision.transforms import InterpolationMode\n",
196
+ "\n",
197
+ " BICUBIC = InterpolationMode.BICUBIC\n",
198
+ "except ImportError:\n",
199
+ " BICUBIC = Image.BICUBIC\n",
200
+ " \n",
201
+ "def _preprocess_scribble(img, img_size):\n",
202
+ " transform = T.Compose(\n",
203
+ " [\n",
204
+ " T.Resize(img_size, BICUBIC),\n",
205
+ " T.CenterCrop(img_size),\n",
206
+ " T.ToTensor(),\n",
207
+ " ]\n",
208
+ " )\n",
209
+ " return transform(img)"
210
+ ]
211
+ },
212
+ {
213
+ "cell_type": "code",
214
+ "execution_count": null,
215
+ "metadata": {},
216
+ "outputs": [],
217
+ "source": [
218
+ "scribble = _preprocess_scribble(scr, 224)\n",
219
+ "#scribble = (scribble > 0).float() # threshold to [0,1]\n",
220
+ "#scribble = torch.max(scribble) - scribble # inverted scribble"
221
+ ]
222
+ },
223
+ {
224
+ "cell_type": "code",
225
+ "execution_count": null,
226
+ "metadata": {},
227
+ "outputs": [],
228
+ "source": [
229
+ "scribble.shape"
230
+ ]
231
+ },
232
+ {
233
+ "cell_type": "code",
234
+ "execution_count": null,
235
+ "metadata": {},
236
+ "outputs": [],
237
+ "source": [
238
+ "import cv2\n",
239
+ "import numpy as np \n",
240
+ "\n",
241
+ "tens =(scribble.permute(1,2,0).detach().cpu().numpy())\n",
242
+ "tens=(tens*255).astype(np.uint8)\n",
243
+ "#ten=cv2.cvtColor(ten,cv2.COLOR_RGB2BGR)\n",
244
+ "tens.shape"
245
+ ]
246
+ },
247
+ {
248
+ "cell_type": "code",
249
+ "execution_count": null,
250
+ "metadata": {},
251
+ "outputs": [],
252
+ "source": [
253
+ "plt.imshow(tens, cmap='gray')\n",
254
+ "plt.axis('off')\n",
255
+ "plt.savefig('gt.png', bbox_inches='tight', pad_inches=0)"
256
+ ]
257
+ },
258
+ {
259
+ "cell_type": "code",
260
+ "execution_count": null,
261
+ "metadata": {},
262
+ "outputs": [],
263
+ "source": []
264
+ },
265
+ {
266
+ "cell_type": "code",
267
+ "execution_count": null,
268
+ "metadata": {},
269
+ "outputs": [],
270
+ "source": [
271
+ "masked_img_t = img * scribble"
272
+ ]
273
+ },
274
+ {
275
+ "cell_type": "code",
276
+ "execution_count": null,
277
+ "metadata": {},
278
+ "outputs": [],
279
+ "source": []
280
+ },
281
+ {
282
+ "cell_type": "code",
283
+ "execution_count": null,
284
+ "metadata": {},
285
+ "outputs": [],
286
+ "source": [
287
+ "model = FoundModel(vit_model=\"dino\",\n",
288
+ " vit_arch=\"vit_small\",\n",
289
+ " vit_patch_size=8,\n",
290
+ " enc_type_feats=\"k\",\n",
291
+ " bkg_type_feats=\"k\",\n",
292
+ " bkg_th=0.3)\n",
293
+ "\n",
294
+ "# Load weights\n",
295
+ "model.decoder_load_weights(\"./outputs/msl_a1.5_b1_g1_reg4-MSL-DUTS-TR-vit_small8/decoder_weights_niter500.pt\")\n",
296
+ "model.eval()"
297
+ ]
298
+ },
299
+ {
300
+ "cell_type": "code",
301
+ "execution_count": null,
302
+ "metadata": {},
303
+ "outputs": [],
304
+ "source": [
305
+ "# Forward step\n",
306
+ "with torch.no_grad():\n",
307
+ " preds, _, shape_f, att = model.forward_step(inputs, for_eval=True)\n",
308
+ "\n",
309
+ "# Apply FOUND\n",
310
+ "sigmoid = nn.Sigmoid()\n",
311
+ "h, w = img_t.shape[-2:]\n",
312
+ "preds_up = F.interpolate(\n",
313
+ " preds, scale_factor=model.vit_patch_size, mode=\"bilinear\", align_corners=False\n",
314
+ ")[..., :h, :w]\n",
315
+ "preds_up = (\n",
316
+ " (sigmoid(preds_up.detach()) > 0.5).squeeze(0).float()\n",
317
+ ")"
318
+ ]
319
+ },
320
+ {
321
+ "cell_type": "code",
322
+ "execution_count": null,
323
+ "metadata": {},
324
+ "outputs": [],
325
+ "source": [
326
+ "plt.imshow(preds_up.cpu().squeeze().numpy(), cmap='gray')\n",
327
+ "plt.axis('off')\n",
328
+ "plt.savefig('masked_pred.png', bbox_inches='tight', pad_inches=0)"
329
+ ]
330
+ },
331
+ {
332
+ "cell_type": "code",
333
+ "execution_count": null,
334
+ "metadata": {},
335
+ "outputs": [],
336
+ "source": [
337
+ "preds_up.shape"
338
+ ]
339
+ },
340
+ {
341
+ "cell_type": "code",
342
+ "execution_count": null,
343
+ "metadata": {},
344
+ "outputs": [],
345
+ "source": []
346
+ },
347
+ {
348
+ "cell_type": "code",
349
+ "execution_count": null,
350
+ "metadata": {},
351
+ "outputs": [],
352
+ "source": [
353
+ "def read_image(path):\n",
354
+ " image = cv2.imread(path, -1)\n",
355
+ " image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
356
+ " image = make_border(image)\n",
357
+ " return image\n",
358
+ "\n",
359
+ "\n",
360
+ "def make_border(im):\n",
361
+ " row, col = im.shape[:2]\n",
362
+ " bottom = im[row-2:row, 0:col]\n",
363
+ " mean = cv2.mean(bottom)[0]\n",
364
+ " bordersize = 5\n",
365
+ " border = cv2.copyMakeBorder(\n",
366
+ " im,\n",
367
+ " top=bordersize,\n",
368
+ " bottom=bordersize,\n",
369
+ " left=bordersize,\n",
370
+ " right=bordersize,\n",
371
+ " borderType=cv2.BORDER_CONSTANT,\n",
372
+ " value=[0, 0, 0]\n",
373
+ " )\n",
374
+ " return border"
375
+ ]
376
+ },
377
+ {
378
+ "cell_type": "code",
379
+ "execution_count": null,
380
+ "metadata": {},
381
+ "outputs": [],
382
+ "source": [
383
+ "img = read_image(\"./notebooks/scribble.png\")"
384
+ ]
385
+ },
386
+ {
387
+ "cell_type": "code",
388
+ "execution_count": null,
389
+ "metadata": {},
390
+ "outputs": [],
391
+ "source": [
392
+ "plt.imshow(img)\n",
393
+ "plt.axis('off')\n",
394
+ "plt.savefig('scribble.png', bbox_inches='tight', pad_inches=0)"
395
+ ]
396
+ },
397
+ {
398
+ "cell_type": "code",
399
+ "execution_count": null,
400
+ "metadata": {},
401
+ "outputs": [],
402
+ "source": []
403
+ },
404
+ {
405
+ "cell_type": "code",
406
+ "execution_count": null,
407
+ "metadata": {},
408
+ "outputs": [],
409
+ "source": []
410
+ }
411
+ ],
412
+ "metadata": {
413
+ "kernelspec": {
414
+ "display_name": "tarmak",
415
+ "language": "python",
416
+ "name": "python3"
417
+ },
418
+ "language_info": {
419
+ "codemirror_mode": {
420
+ "name": "ipython",
421
+ "version": 3
422
+ },
423
+ "file_extension": ".py",
424
+ "mimetype": "text/x-python",
425
+ "name": "python",
426
+ "nbconvert_exporter": "python",
427
+ "pygments_lexer": "ipython3",
428
+ "version": "3.8.18"
429
+ },
430
+ "orig_nbformat": 4
431
+ },
432
+ "nbformat": 4,
433
+ "nbformat_minor": 2
434
+ }
notebooks/graphs.ipynb ADDED
@@ -0,0 +1,249 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "%load_ext autoreload\n",
10
+ "%autoreload 2"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": null,
16
+ "metadata": {},
17
+ "outputs": [],
18
+ "source": [
19
+ "import pandas as pd\n",
20
+ "import numpy as np\n",
21
+ "import sys\n",
22
+ "import matplotlib.pyplot as plt\n",
23
+ "import seaborn as sns\n",
24
+ "import os \n",
25
+ "sns.set()\n",
26
+ "\n",
27
+ "%matplotlib inline\n",
28
+ "import warnings\n",
29
+ "warnings.filterwarnings('ignore')\n",
30
+ "\n",
31
+ "# https://abdalimran.github.io/2019-06-01/Drawing-multiple-ROC-Curves-in-a-single-plot"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": null,
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "#labels = ['Baseline', 'MaskSup']\n",
41
+ "labels = ['VOC07', 'VOC12', 'COCO20K']\n",
42
+ "\n",
43
+ "# VOC\n",
44
+ "auc = [71.7, 75.6, 62] # base\n",
45
+ "acc_nst = [72.7, 75.9, 64.0]\n",
46
+ "\n",
47
+ "# COCO\n",
48
+ "# auc = [54.2,36.0,48.4] # base\n",
49
+ "# acc_nst = [74.8,59.4,68.8]\n",
50
+ "\n",
51
+ "x = np.arange(len(labels)) # the label locations\n",
52
+ "dummy = np.arange(10)\n",
53
+ "\n",
54
+ "width = 0.35 #0.4 # the width of the bars\n",
55
+ "\n",
56
+ "\n",
57
+ "\n",
58
+ "fig, ax = plt.subplots()\n",
59
+ "\n",
60
+ "rects1 = ax.bar(x - width/2, auc, width, label='low masking', color='#E96479') # #FFAE6D\n",
61
+ "rects2 = ax.bar(x + width/2, acc_nst, width, label='high masking', color='#7DB9B6') # #9ED2C6\n",
62
+ "#rects211 = ax.bar(x + width/2 * 3.08, acc, width, label='CF1')\n",
63
+ "\n",
64
+ "#ax.set_ylabel('CorLoc (%)', fontsize=20)\n",
65
+ "#ax.set_title('Results')\n",
66
+ "ax.set_xticks(x)\n",
67
+ "ax.set_xticklabels(labels, rotation=0, fontsize=20)\n",
68
+ "\n",
69
+ "#for i in range(18):\n",
70
+ "# ax.get_xticklabels()[i].set_color(\"white\")\n",
71
+ "\n",
72
+ "#ax.set_ylim([30,80]) # coc\n",
73
+ "ax.set_ylim([60,80]) # voc\n",
74
+ "\n",
75
+ "#ax.legend(loc=\"upper left\", prop={'size': 14})\n",
76
+ "ax.grid(True)\n",
77
+ "#ax.patch.set_facecolor('white')\n",
78
+ "\n",
79
+ "def autolabel(rects):\n",
80
+ " \"\"\"Attach a text label above each bar in *rects*, displaying its height.\"\"\"\n",
81
+ " for rect in rects:\n",
82
+ " height = rect.get_height()\n",
83
+ " ax.annotate('{:.1f}'.format(height),\n",
84
+ " xy=(rect.get_x() + rect.get_width() / 2, height),\n",
85
+ " xytext=(0, 3), # 3 points vertical offset\n",
86
+ " textcoords=\"offset points\",\n",
87
+ " ha='center', va='bottom', rotation=0, fontsize=15)\n",
88
+ " #ax.set_ylim(ymin=1)\n",
89
+ " \n",
90
+ "\n",
91
+ "def autolabel_(rects):\n",
92
+ " \"\"\"Attach a text label above each bar in *rects*, displaying its height.\"\"\"\n",
93
+ " for rect in rects:\n",
94
+ " height = rect.get_height()\n",
95
+ " ax.annotate('{:.1f}'.format(height),\n",
96
+ " xy=(rect.get_x() + rect.get_width() / 2, height),\n",
97
+ " xytext=(0, 3), # 3 points vertical offset\n",
98
+ " textcoords=\"offset points\",\n",
99
+ " ha='center', va='bottom', rotation=0, fontsize=15)\n",
100
+ " #ax.set_ylim(ymin=1)\n",
101
+ "\n",
102
+ "\n",
103
+ "autolabel(rects1) # %\n",
104
+ "autolabel(rects2)\n",
105
+ "#autolabel_(rects211) # %\n",
106
+ "\n",
107
+ "fig.tight_layout()\n",
108
+ "fig.set_size_inches(12, 4, forward=True)\n",
109
+ "plt.title('Impact of masking (\\u2191)', loc='left', fontsize=25, color='gray', pad=12)\n",
110
+ "#plt.title('VOC2007 (\\u2191)', loc='left', fontsize=25, color='gray', pad=12)\n",
111
+ "plt.legend(loc='upper right', fontsize=18)\n",
112
+ "plt.savefig(\"../logs/masking_ablation.pdf\", bbox_inches='tight', pad_inches=0, dpi=300)\n",
113
+ "plt.show()"
114
+ ]
115
+ },
116
+ {
117
+ "cell_type": "code",
118
+ "execution_count": null,
119
+ "metadata": {},
120
+ "outputs": [],
121
+ "source": []
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": null,
126
+ "metadata": {},
127
+ "outputs": [],
128
+ "source": []
129
+ },
130
+ {
131
+ "cell_type": "code",
132
+ "execution_count": null,
133
+ "metadata": {},
134
+ "outputs": [],
135
+ "source": []
136
+ },
137
+ {
138
+ "cell_type": "code",
139
+ "execution_count": null,
140
+ "metadata": {},
141
+ "outputs": [],
142
+ "source": [
143
+ "#labels = ['Baseline', 'MaskSup']\n",
144
+ "labels = ['VOC07', 'VOC12', 'COCO20K']\n",
145
+ "\n",
146
+ "# VOC\n",
147
+ "auc_b = [71.6, 75.2, 61.8] # base\n",
148
+ "auc = [72.2, 75.5, 62.3] # base\n",
149
+ "acc_nst = [72.7, 75.9, 64.0]\n",
150
+ "\n",
151
+ "# COCO\n",
152
+ "# auc = [54.2,36.0,48.4] # base\n",
153
+ "# acc_nst = [74.8,59.4,68.8]\n",
154
+ "\n",
155
+ "x = np.arange(len(labels)) # the label locations\n",
156
+ "dummy = np.arange(10)\n",
157
+ "\n",
158
+ "width = 0.25 #0.4 # the width of the bars\n",
159
+ "\n",
160
+ "\n",
161
+ "\n",
162
+ "fig, ax = plt.subplots()\n",
163
+ "\n",
164
+ "rects1 = ax.bar(x - width/2, auc_b, width, label='Baseline', color='#E96479') # #FFAE6D\n",
165
+ "rects2 = ax.bar(x + width/2, auc, width, label='w/ MFP', color='#7DB9B6') # #9ED2C6\n",
166
+ "rects211 = ax.bar(x + width/2 * 3.08, acc_nst, width, label='w/ MFP + PCL', color='#FFAE6D')\n",
167
+ "\n",
168
+ "ax.set_ylabel('CorLoc (%)', fontsize=20)\n",
169
+ "#ax.set_title('Results')\n",
170
+ "ax.set_xticks(x)\n",
171
+ "ax.set_xticklabels(labels, rotation=0, fontsize=20)\n",
172
+ "\n",
173
+ "#for i in range(18):\n",
174
+ "# ax.get_xticklabels()[i].set_color(\"white\")\n",
175
+ "\n",
176
+ "#ax.set_ylim([30,80]) # coc\n",
177
+ "ax.set_ylim([60,80]) # voc\n",
178
+ "\n",
179
+ "#ax.legend(loc=\"upper left\", prop={'size': 14})\n",
180
+ "ax.grid(True)\n",
181
+ "#ax.patch.set_facecolor('white')\n",
182
+ "\n",
183
+ "def autolabel(rects):\n",
184
+ " \"\"\"Attach a text label above each bar in *rects*, displaying its height.\"\"\"\n",
185
+ " for rect in rects:\n",
186
+ " height = rect.get_height()\n",
187
+ " ax.annotate('{:.1f}'.format(height),\n",
188
+ " xy=(rect.get_x() + rect.get_width() / 2, height),\n",
189
+ " xytext=(0, 3), # 3 points vertical offset\n",
190
+ " textcoords=\"offset points\",\n",
191
+ " ha='center', va='bottom', rotation=0, fontsize=15)\n",
192
+ " #ax.set_ylim(ymin=1)\n",
193
+ " \n",
194
+ "\n",
195
+ "def autolabel_(rects):\n",
196
+ " \"\"\"Attach a text label above each bar in *rects*, displaying its height.\"\"\"\n",
197
+ " for rect in rects:\n",
198
+ " height = rect.get_height()\n",
199
+ " ax.annotate('{:.1f}'.format(height),\n",
200
+ " xy=(rect.get_x() + rect.get_width() / 2, height),\n",
201
+ " xytext=(0, 3), # 3 points vertical offset\n",
202
+ " textcoords=\"offset points\",\n",
203
+ " ha='center', va='bottom', rotation=0, fontsize=15)\n",
204
+ " #ax.set_ylim(ymin=1)\n",
205
+ "\n",
206
+ "\n",
207
+ "autolabel(rects1) # %\n",
208
+ "autolabel(rects2)\n",
209
+ "autolabel_(rects211) # %\n",
210
+ "\n",
211
+ "fig.tight_layout()\n",
212
+ "fig.set_size_inches(12, 4, forward=True)\n",
213
+ "plt.title('Effectiveness of MFP and PCL (\\u2191)', loc='left', fontsize=25, color='gray', pad=12)\n",
214
+ "#plt.title('VOC2007 (\\u2191)', loc='left', fontsize=25, color='gray', pad=12)\n",
215
+ "plt.legend(loc='upper right', fontsize=18)\n",
216
+ "plt.savefig(\"../logs/msl_ablation.pdf\", bbox_inches='tight', pad_inches=0, dpi=300)\n",
217
+ "plt.show()"
218
+ ]
219
+ },
220
+ {
221
+ "cell_type": "code",
222
+ "execution_count": null,
223
+ "metadata": {},
224
+ "outputs": [],
225
+ "source": []
226
+ }
227
+ ],
228
+ "metadata": {
229
+ "kernelspec": {
230
+ "display_name": "bdstreets",
231
+ "language": "python",
232
+ "name": "python3"
233
+ },
234
+ "language_info": {
235
+ "codemirror_mode": {
236
+ "name": "ipython",
237
+ "version": 3
238
+ },
239
+ "file_extension": ".py",
240
+ "mimetype": "text/x-python",
241
+ "name": "python",
242
+ "nbconvert_exporter": "python",
243
+ "pygments_lexer": "ipython3",
244
+ "version": "3.8.17"
245
+ }
246
+ },
247
+ "nbformat": 4,
248
+ "nbformat_minor": 2
249
+ }
notebooks/visualize.ipynb ADDED
@@ -0,0 +1,571 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "%load_ext autoreload\n",
10
+ "%autoreload 2"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": null,
16
+ "metadata": {},
17
+ "outputs": [],
18
+ "source": [
19
+ "import os,sys,inspect\n",
20
+ "sys.path.insert(0,\"..\")\n",
21
+ "\n",
22
+ "import matplotlib.pyplot as plt\n",
23
+ "from matplotlib import rc\n",
24
+ "import glob\n",
25
+ "\n",
26
+ "macos = False\n",
27
+ "if macos == True:\n",
28
+ " rc('font',**{'family':'sans-serif','sans-serif':['Computer Modern Roman']})\n",
29
+ " rc('text', usetex=True)\n",
30
+ "\n",
31
+ "# Font Size\n",
32
+ "import matplotlib\n",
33
+ "font = {'family' : 'DejaVu Sans',\n",
34
+ " 'weight' : 'bold',\n",
35
+ " 'size' : 30}\n",
36
+ "\n",
37
+ "import cv2\n",
38
+ "import numpy as np\n",
39
+ "import string\n",
40
+ "import random"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": null,
46
+ "metadata": {},
47
+ "outputs": [],
48
+ "source": [
49
+ "def visualize(idx, **images):\n",
50
+ " \"\"\"Plot images in one row.\"\"\" \n",
51
+ " n = len(images)\n",
52
+ " fig = plt.figure(figsize=(60, 40))\n",
53
+ " for i, (name, image) in enumerate(images.items()):\n",
54
+ " plt.subplot(1, n, i + 1)\n",
55
+ " plt.xticks([])\n",
56
+ " plt.yticks([])\n",
57
+ " #if idx==0:\n",
58
+ " plt.title(' '.join(name.split('_')).lower(), fontsize=40)\n",
59
+ " if i ==0:\n",
60
+ " w,h = (1,25)\n",
61
+ " fs = 1.0\n",
62
+ " color = (0,0,0)\n",
63
+ " #color = (255,255,255)\n",
64
+ " font = cv2.FONT_HERSHEY_SIMPLEX #FONT_HERSHEY_DUPLEX #press tab for different operations\n",
65
+ " cv2.putText(image, str(idx), (w,h), font, fs, color, 1, cv2.LINE_AA)\n",
66
+ " if i !=0:\n",
67
+ " #plt.imshow(image[:,:,0], cmap='magma')\n",
68
+ " plt.imshow(image, cmap='gray')\n",
69
+ " else:\n",
70
+ " plt.imshow(image, cmap='gray')\n",
71
+ " plt.axis(\"off\")\n",
72
+ " #plt.tight_layout()\n",
73
+ " plt.savefig(\"../outputs/visualizations/duts-te/compare-preds/{}.png\".format(idx), facecolor=\"white\", bbox_inches = 'tight')\n",
74
+ " plt.show()\n",
75
+ " \n",
76
+ " \n",
77
+ "def make_dataset(dir):\n",
78
+ " images = []\n",
79
+ " assert os.path.isdir(dir), '%s is not a valid directory' % dir\n",
80
+ "\n",
81
+ " f = dir.split('/')[-1].split('_')[-1]\n",
82
+ " #print (dir, f)\n",
83
+ " dirs= os.listdir(dir)\n",
84
+ " for img in dirs:\n",
85
+ "\n",
86
+ " path = os.path.join(dir, img)\n",
87
+ " #print(path)\n",
88
+ " images.append(path)\n",
89
+ " return images\n",
90
+ "\n",
91
+ "# def make_dataset(dir):\n",
92
+ "# images = []\n",
93
+ "# assert os.path.isdir(dir), '%s is not a valid directory' % dir\n",
94
+ "\n",
95
+ "# # f = dir.split('/')[-1].split('_')[-1]\n",
96
+ "# # #print (dir, f)\n",
97
+ "# # dirs= os.listdir(dir)\n",
98
+ "# # for img in dirs:\n",
99
+ "\n",
100
+ "# # path = os.path.join(dir, img)\n",
101
+ "# # #print(path)\n",
102
+ "# # images.append(path)\n",
103
+ "# images = natsorted(glob.glob(dir+ \"/\" + \"/*.png\"))\n",
104
+ "# return images\n",
105
+ "\n",
106
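+ "# Load an image, convert BGR to RGB, resize and center-crop to 224x224, then add a black border\n",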
+ "def read_image(path):\n",
107
+ " image = cv2.imread(path, -1)\n",
108
+ " image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
109
+ " image = Image.fromarray(np.uint8(image)).convert('RGB')\n",
110
+ " image = resize_center_crop(image)\n",
111
+ " image = make_border(image)\n",
112
+ " return image\n",
113
+ "\n",
114
+ "\n",
115
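+ "# Add a constant 5-pixel black border around the image\n",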
+ "def make_border(im):\n",
116
+ " row, col = im.shape[:2]\n",
117
+ " bottom = im[row-2:row, 0:col]\n",
118
+ " mean = cv2.mean(bottom)[0]\n",
119
+ " bordersize = 5\n",
120
+ " border = cv2.copyMakeBorder(\n",
121
+ " im,\n",
122
+ " top=bordersize,\n",
123
+ " bottom=bordersize,\n",
124
+ " left=bordersize,\n",
125
+ " right=bordersize,\n",
126
+ " borderType=cv2.BORDER_CONSTANT,\n",
127
+ " value=[0, 0, 0]\n",
128
+ " )\n",
129
+ " return border\n",
130
+ "\n",
131
+ "from PIL import Image\n",
132
+ "from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize\n",
133
+ "\n",
134
+ "try:\n",
135
+ " from torchvision.transforms import InterpolationMode\n",
136
+ "\n",
137
+ " BICUBIC = InterpolationMode.BICUBIC\n",
138
+ "except ImportError:\n",
139
+ " BICUBIC = Image.BICUBIC\n",
140
+ "\n",
141
+ "\n",
142
+ "def _convert_image_to_rgb(image):\n",
143
+ " return image.convert(\"RGB\")\n",
144
+ "\n",
145
+ "\n",
146
+ "def resize_center_crop(img):\n",
147
+ " \"\"\" \n",
148
+ " Load and resize an image to a desired size.\n",
149
+ "\n",
150
+ " Arguments:\n",
151
+ " img (PIL image): Image to load and resize\n",
152
+ "\n",
153
+ " Returns:\n",
154
+ " img (np.array): Resized and cropped image\n",
155
+ "\n",
156
+ " Examples:\n",
157
+ " >>> img = resize_center_crop(img)\n",
158
+ " \"\"\"\n",
159
+ "\n",
160
+ " if type(img) == str:\n",
161
+ " img = Image.open(img)\n",
162
+ "\n",
163
+ " transform = Compose(\n",
164
+ " [\n",
165
+ " Resize(224, BICUBIC),\n",
166
+ " CenterCrop(224),\n",
167
+ " _convert_image_to_rgb,\n",
168
+ " # ToTensor(),\n",
169
+ " # Normalize(\n",
170
+ " # (0.5, 0.5, 0.5),\n",
171
+ " # (0.5, 0.5, 0.5),\n",
172
+ " # ),\n",
173
+ " ]\n",
174
+ " )\n",
175
+ " img = transform(img)\n",
176
+ " img = np.array(img)\n",
177
+ " return img\n",
178
+ "\n",
179
+ "def read_image_(path):\n",
180
+ " image = cv2.imread(path, -1)\n",
181
+ " image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
182
+ " image = cv2.resize(image, (192, 256))\n",
183
+ " return image"
184
+ ]
185
+ },
186
+ {
187
+ "cell_type": "code",
188
+ "execution_count": null,
189
+ "metadata": {},
190
+ "outputs": [],
191
+ "source": [
192
+ "# # ECSSD\n",
193
+ "\n",
194
+ "# # Images and GT\n",
195
+ "\n",
196
+ "# GT = \"../outputs/visualizations/ecssd/gts\"\n",
197
+ "# IMG = \"../datasets_local/ECSSD/images/\"\n",
198
+ "# GTS = [os.path.join(GT, x) for x in os.listdir(GT)]\n",
199
+ "# IMGS = [os.path.join(IMG, x) for x in os.listdir(IMG)]\n",
200
+ "\n",
201
+ "# # Algo\n",
202
+ "# algo1 = \"../outputs/visualizations/ecssd/found-MSL-DUTS-TR-vit_small8_ECSSD/\"\n",
203
+ "# ours = \"../outputs/visualizations/ecssd/msl_a1.5_b1_g1_reg4-MSL-DUTS-TR-vit_small8_ECSSD/\"\n",
204
+ "\n",
205
+ "# algo1 = [os.path.join(algo1, x) for x in os.listdir(algo1)]\n",
206
+ "# ours = [os.path.join(ours, x) for x in os.listdir(ours)]\n",
207
+ "\n",
208
+ "# print(len(GTS), len(IMGS))\n",
209
+ "# print(ours[:3])\n",
210
+ "\n",
211
+ "# i = 0\n",
212
+ "# for num in range(len(IMGS)):\n",
213
+ "# visualize(i, \n",
214
+ "# image=read_image(IMGS[num]),\n",
215
+ "# found_method=read_image(algo1[num]),\n",
216
+ "# our_method=read_image(ours[num]),\n",
217
+ "# gt=read_image(GTS[num]))\n",
218
+ "# i+=1"
219
+ ]
220
+ },
221
+ {
222
+ "cell_type": "code",
223
+ "execution_count": null,
224
+ "metadata": {},
225
+ "outputs": [],
226
+ "source": [
227
+ "# # DUT_OMRON\n",
228
+ "\n",
229
+ "# # Images and GT\n",
230
+ "\n",
231
+ "# GT = \"../outputs/visualizations/dut-omron/gts\"\n",
232
+ "# IMG = \"../datasets_local/DUT-OMRON/DUT-OMRON-image/\"\n",
233
+ "# GTS = [os.path.join(GT, x) for x in os.listdir(GT)]\n",
234
+ "# IMGS = [os.path.join(IMG, x) for x in os.listdir(IMG)]\n",
235
+ "\n",
236
+ "# # Algo\n",
237
+ "# algo1 = \"../outputs/visualizations/dut-omron/found-MSL-DUTS-TR-vit_small8_DUT-OMRON/\"\n",
238
+ "# ours = \"../outputs/visualizations/dut-omron/msl_a1.5_b1_g1_reg4-MSL-DUTS-TR-vit_small8_DUT-OMRON/\"\n",
239
+ "\n",
240
+ "# algo1 = [os.path.join(algo1, x) for x in os.listdir(algo1)]\n",
241
+ "# ours = [os.path.join(ours, x) for x in os.listdir(ours)]\n",
242
+ "\n",
243
+ "# print(len(GTS), len(IMGS))\n",
244
+ "# print(ours[:3])\n",
245
+ "\n",
246
+ "# i = 0\n",
247
+ "# for num in range(len(IMGS)):\n",
248
+ "# visualize(i, \n",
249
+ "# image=read_image(IMGS[num]),\n",
250
+ "# found_method=read_image(algo1[num]),\n",
251
+ "# our_method=read_image(ours[num]),\n",
252
+ "# gt=read_image(GTS[num]))\n",
253
+ "# i+=1"
254
+ ]
255
+ },
256
+ {
257
+ "cell_type": "code",
258
+ "execution_count": null,
259
+ "metadata": {},
260
+ "outputs": [],
261
+ "source": [
262
+ "# # DUT-TE\n",
263
+ "\n",
264
+ "# # Images and GT\n",
265
+ "\n",
266
+ "# GT = \"../outputs/visualizations/duts-te/gts\"\n",
267
+ "# IMG = \"../datasets_local/DUTS-TE/DUTS-TE-Image/\"\n",
268
+ "# GTS = [os.path.join(GT, x) for x in os.listdir(GT)]\n",
269
+ "# IMGS = [os.path.join(IMG, x) for x in os.listdir(IMG)]\n",
270
+ "\n",
271
+ "# # Algo\n",
272
+ "# algo1 = \"../outputs/visualizations/duts-te/found-MSL-DUTS-TR-vit_small8_DUTS-TE/\"\n",
273
+ "# ours = \"../outputs/visualizations/duts-te/msl_a1.5_b1_g1_reg4-MSL-DUTS-TR-vit_small8_DUTS-TE/\"\n",
274
+ "\n",
275
+ "# algo1 = [os.path.join(algo1, x) for x in os.listdir(algo1)]\n",
276
+ "# ours = [os.path.join(ours, x) for x in os.listdir(ours)]\n",
277
+ "\n",
278
+ "# print(len(GTS), len(IMGS))\n",
279
+ "# print(ours[:3])\n",
280
+ "\n",
281
+ "# i = 0\n",
282
+ "# for num in range(len(IMGS)):\n",
283
+ "# visualize(i, \n",
284
+ "# image=read_image(IMGS[num]),\n",
285
+ "# found_method=read_image(algo1[num]),\n",
286
+ "# our_method=read_image(ours[num]),\n",
287
+ "# gt=read_image(GTS[num]))\n",
288
+ "# i+=1"
289
+ ]
290
+ },
291
+ {
292
+ "cell_type": "code",
293
+ "execution_count": null,
294
+ "metadata": {},
295
+ "outputs": [],
296
+ "source": [
297
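+ "# ECSSD: image, ground-truth and prediction paths\n",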
+ "# GT\n",
298
+ "ECSS_GT = \"../outputs/visualizations/ecssd/gts\"\n",
299
+ "ECSS_IMG = \"../datasets_local/ECSSD/images/\"\n",
300
+ "ECSS_GTS = [os.path.join(ECSS_GT, x) for x in os.listdir(ECSS_GT)]\n",
301
+ "ECSS_IMGS = [os.path.join(ECSS_IMG, x) for x in os.listdir(ECSS_IMG)]\n",
302
+ "# Pred\n",
303
+ "ECSS_algo1 = \"../outputs/visualizations/ecssd/found-MSL-DUTS-TR-vit_small8_ECSSD/\"\n",
304
+ "ECSS_ours = \"../outputs/visualizations/ecssd/msl_a1.5_b1_g1_reg4-MSL-DUTS-TR-vit_small8_ECSSD/\"\n",
305
+ "ECSS_algo1 = [os.path.join(ECSS_algo1, x) for x in os.listdir(ECSS_algo1)]\n",
306
+ "ECSS_ours = [os.path.join(ECSS_ours, x) for x in os.listdir(ECSS_ours)]\n"
307
+ ]
308
+ },
309
+ {
310
+ "cell_type": "code",
311
+ "execution_count": null,
312
+ "metadata": {},
313
+ "outputs": [],
314
+ "source": [
315
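+ "# DUT-OMRON: image, ground-truth and prediction paths\n",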
+ "# GT\n",
316
+ "DUT_OM_GT = \"../outputs/visualizations/dut-omron/gts\"\n",
317
+ "DUT_OM_IMG = \"../datasets_local/DUT-OMRON/DUT-OMRON-image/\"\n",
318
+ "DUT_OM_GTS = [os.path.join(DUT_OM_GT, x) for x in os.listdir(DUT_OM_GT)]\n",
319
+ "DUT_OM_IMGS = [os.path.join(DUT_OM_IMG, x) for x in os.listdir(DUT_OM_IMG)]\n",
320
+ "\n",
321
+ "# Pred\n",
322
+ "DUT_OM_algo1 = \"../outputs/visualizations/dut-omron/found-MSL-DUTS-TR-vit_small8_DUT-OMRON/\"\n",
323
+ "DUT_OM_ours = \"../outputs/visualizations/dut-omron/msl_a1.5_b1_g1_reg4-MSL-DUTS-TR-vit_small8_DUT-OMRON/\"\n",
324
+ "DUT_OM_algo1 = [os.path.join(DUT_OM_algo1, x) for x in os.listdir(DUT_OM_algo1)]\n",
325
+ "DUT_OM_ours = [os.path.join(DUT_OM_ours, x) for x in os.listdir(DUT_OM_ours)]"
326
+ ]
327
+ },
328
+ {
329
+ "cell_type": "code",
330
+ "execution_count": null,
331
+ "metadata": {},
332
+ "outputs": [],
333
+ "source": [
334
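+ "# DUTS-TE: image, ground-truth and prediction paths\n",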
+ "DUT_GT = \"../outputs/visualizations/duts-te/gts\"\n",
335
+ "DUT_IMG = \"../datasets_local/DUTS-TE/DUTS-TE-Image/\"\n",
336
+ "DUT_GTS = [os.path.join(DUT_GT, x) for x in os.listdir(DUT_GT)]\n",
337
+ "DUT_IMGS = [os.path.join(DUT_IMG, x) for x in os.listdir(DUT_IMG)]\n",
338
+ "\n",
339
+ "# Pred\n",
340
+ "DUT_algo1 = \"../outputs/visualizations/duts-te/found-MSL-DUTS-TR-vit_small8_DUTS-TE/\"\n",
341
+ "DUT_ours = \"../outputs/visualizations/duts-te/msl_a1.5_b1_g1_reg4-MSL-DUTS-TR-vit_small8_DUTS-TE/\"\n",
342
+ "DUT_algo1 = [os.path.join(DUT_algo1, x) for x in os.listdir(DUT_algo1)]\n",
343
+ "DUT_ours = [os.path.join(DUT_ours, x) for x in os.listdir(DUT_ours)]\n"
344
+ ]
345
+ },
346
+ {
347
+ "cell_type": "code",
348
+ "execution_count": null,
349
+ "metadata": {},
350
+ "outputs": [],
351
+ "source": [
352
+ "\n",
353
+ "# ECSSD -\n",
354
+ "# \t52, 132, 147 - over segmentation, fine details\n",
355
+ "# \t353, 658, 780 - reflection of shiny surface and water\n",
356
+ "# 432, 825, 835, 988 - noisy \n",
357
+ "# 59 (bee) - complex background\n",
358
+ "\n",
359
+ "# DUT-OMRON\n",
360
+ "# \t1, 14 - over segmentation\n",
361
+ "# \t119, 365, 439, 440, 1238 - noisy\n",
362
+ "# 1168, 1461 - segment other non-salient objects/parts\n",
363
+ "# 1388 - fails in complex background\n",
364
+ "# 1398 - small objects\n",
365
+ "# 1973 - dark scenes\n",
366
+ "\n",
367
+ "# DUTS-TE\n",
368
+ "# \t46, 698, 1712 - segment other non-salient objects/parts\n",
369
+ "# \t260 - small objects \n",
370
+ "# 776, 1255 - over segmentation\n",
371
+ "# \t683, 830, 1465 - noisy\n",
372
+ "# \t719, 1470 - reflection of water\n",
373
+ "\n",
374
+ "# 52, 132, 147, 353, 658, 780, - oversegment, reflection of shiny surface and water\n",
375
+ "# 1388, 1398, 1972, 1168, 1461, 440 - fails in complex background, small objects, dark scenes, segment non-salient objects, noisy\n",
376
+ "# 260, 719, 1470, 683, 830, 1465 - small objects, reflection of water, noisy predictions\n",
377
+ "\n",
378
+ "# idxs = [52, 59, 147, 353, 658, 780, 1388, 1398, 1973, 1168, 1461, 440, 260, 719, 1470, 683, 830, 1465]\n",
379
+ "\n",
380
+ "\n",
381
+ "\n",
382
+ "\n",
383
+ "# ECSSD -\n",
384
+ "# , 132, - over segmentation, fine details\n",
385
+ "# ,, - reflection of shiny surface and water\n",
386
+ "# 432, 825, 835, 988 - noisy \n",
387
+ "# 59 (bee) - complex background\n",
388
+ "\n",
389
+ "# DUT-OMRON\n",
390
+ "# \t1, 14 - over segmentation\n",
391
+ "# \t119, 365, 439,, 1238 - noisy\n",
392
+ "#, - segment other non-salient objects/parts\n",
393
+ "# - fails in complex background\n",
394
+ "# - small objects\n",
395
+ "# - dark scenes\n",
396
+ "\n",
397
+ "# DUTS-TE\n",
398
+ "# \t46, 698, 1712 - segment other non-salient objects/parts\n",
399
+ "# - small objects \n",
400
+ "# 776, 1255 - over segmentation\n",
401
+ "# ,, - noisy\n",
402
+ "# , - reflection of water\n",
403
+ "\n",
404
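+ "# Indices of the qualitative examples to plot, chosen from the failure cases noted above\n",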
+ "idxs = [132,432,825,835,988,59,1,14,119,365,439,1238,46,698,1712,776,1255,4000]"
405
+ ]
406
+ },
407
+ {
408
+ "cell_type": "code",
409
+ "execution_count": null,
410
+ "metadata": {},
411
+ "outputs": [],
412
+ "source": [
413
+ "rows = int(len(idxs) / 3)\n",
414
+ "rows, len(idxs)"
415
+ ]
416
+ },
417
+ {
418
+ "cell_type": "code",
419
+ "execution_count": null,
420
+ "metadata": {},
421
+ "outputs": [],
422
+ "source": [
423
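+ "# Build a rows x 12 grid: rows 0-1 use ECSSD, rows 2-3 DUT-OMRON, rows 4-5 DUTS-TE;\n",
+ "# each row shows three (Image, FOUND, Ours, Ground Truth) quadruplets.\n",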
+ "rows = int(len(idxs) / 3)\n",
424
+ "cols = 12\n",
425
+ "fig, axarr = plt.subplots(rows, cols, figsize=(30, 15), constrained_layout=True)\n",
426
+ "\n",
427
+ "\n",
428
+ "alphabet_string = string.ascii_lowercase\n",
429
+ "alphabet_list = list(alphabet_string)\n",
430
+ "\n",
431
+ "v = 0\n",
432
+ "for r in range(rows):\n",
433
+ " if r == 0 or r == 1:\n",
434
+ " print(v, r)\n",
435
+ " a=read_image(ECSS_IMGS[idxs[v+r]])\n",
436
+ " b=read_image(ECSS_algo1[idxs[v+r]])\n",
437
+ "\n",
438
+ " c=read_image(ECSS_ours[idxs[v+r]])\n",
439
+ " d=read_image(ECSS_GTS[idxs[v+r]])\n",
440
+ "\n",
441
+ " e=read_image(ECSS_IMGS[idxs[v+r+1]])\n",
442
+ " f=read_image(ECSS_algo1[idxs[v+r+1]])\n",
443
+ "\n",
444
+ " g=read_image(ECSS_ours[idxs[v+r+1]])\n",
445
+ " h=read_image(ECSS_GTS[idxs[v+r+1]])\n",
446
+ "\n",
447
+ " i=read_image(ECSS_IMGS[idxs[v+r+2]])\n",
448
+ " j=read_image(ECSS_algo1[idxs[v+r+2]])\n",
449
+ "\n",
450
+ " k=read_image(ECSS_ours[idxs[v+r+2]])\n",
451
+ " l=read_image(ECSS_GTS[idxs[v+r+2]])\n",
452
+ "\n",
453
+ " if r == 2 or r == 3:\n",
454
+ " print(v, r)\n",
455
+ " a=read_image(DUT_OM_IMGS[idxs[v+r]])\n",
456
+ " b=read_image(DUT_OM_algo1[idxs[v+r]])\n",
457
+ "\n",
458
+ " c=read_image(DUT_OM_ours[idxs[v+r]])\n",
459
+ " d=read_image(DUT_OM_GTS[idxs[v+r]])\n",
460
+ "\n",
461
+ " e=read_image(DUT_OM_IMGS[idxs[v+r+1]])\n",
462
+ " f=read_image(DUT_OM_algo1[idxs[v+r+1]])\n",
463
+ "\n",
464
+ " g=read_image(DUT_OM_ours[idxs[v+r+1]])\n",
465
+ " h=read_image(DUT_OM_GTS[idxs[v+r+1]])\n",
466
+ "\n",
467
+ " i=read_image(DUT_OM_IMGS[idxs[v+r+2]])\n",
468
+ " j=read_image(DUT_OM_algo1[idxs[v+r+2]])\n",
469
+ "\n",
470
+ " k=read_image(DUT_OM_ours[idxs[v+r+2]])\n",
471
+ " l=read_image(DUT_OM_algo1[idxs[v+r+2]])\n",
472
+ "\n",
473
+ " if r == 4 or r == 5:\n",
474
+ " print(v, r)\n",
475
+ " a=read_image(DUT_IMGS[idxs[v+r]])\n",
476
+ " b=read_image(DUT_algo1[idxs[v+r]])\n",
477
+ "\n",
478
+ " c=read_image(DUT_ours[idxs[v+r]])\n",
479
+ " d=read_image(DUT_GTS[idxs[v+r]])\n",
480
+ "\n",
481
+ " e=read_image(DUT_IMGS[idxs[v+r+1]])\n",
482
+ " f=read_image(DUT_algo1[idxs[v+r+1]])\n",
483
+ "\n",
484
+ " g=read_image(DUT_ours[idxs[v+r+1]])\n",
485
+ " h=read_image(DUT_GTS[idxs[v+r+1]])\n",
486
+ "\n",
487
+ " i=read_image(DUT_IMGS[idxs[v+r+2]])\n",
488
+ " j=read_image(DUT_algo1[idxs[v+r+2]])\n",
489
+ "\n",
490
+ " k=read_image(DUT_ours[idxs[v+r+2]])\n",
491
+ " l=read_image(DUT_GTS[idxs[v+r+2]])\n",
492
+ "\n",
493
+ " v+=2\n",
494
+ " \n",
495
+ " images = [a,b,c,d,e,f,g,h,i,j,k,l]\n",
496
+ " \n",
497
+ " captions = [\"Image\", \"FOUND\", \"Ours\", \"Ground Truth\", \n",
498
+ " \"Image\", \"FOUND\", \"Ours\", \"Ground Truth\",\n",
499
+ " \"Image\", \"FOUND\", \"Ours\", \"Ground Truth\"]\n",
500
+ " \n",
501
+ " for c in range(cols):\n",
502
+ " axarr[r, c].imshow(images[c], cmap='gray')\n",
503
+ " axarr[r, c].axis(\"off\")\n",
504
+ " axarr[r, c].set_aspect('equal') \n",
505
+ " if r==0:\n",
506
+ " axarr[r, c].set_title(captions[c], fontsize=25)\n",
507
+ "\n",
508
+ "plt.savefig(\"../logs/compare_predictions_ext.pdf\", facecolor=\"white\", bbox_inches = 'tight', dpi=300)"
509
+ ]
510
+ },
511
+ {
512
+ "cell_type": "code",
513
+ "execution_count": null,
514
+ "metadata": {},
515
+ "outputs": [],
516
+ "source": [
517
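+ "# Horizontally stack the last loaded image with its FOUND and Peekaboo predictions for the failures figure\n",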
+ "stacks = np.hstack([a,b,c_])\n",
518
+ "stacks.shape"
519
+ ]
520
+ },
521
+ {
522
+ "cell_type": "code",
523
+ "execution_count": null,
524
+ "metadata": {},
525
+ "outputs": [],
526
+ "source": [
527
+ "plt.imshow(stacks)\n",
528
+ "plt.axis(\"off\")\n",
529
+ "plt.savefig(\"../logs/failures.pdf\", facecolor=\"white\", bbox_inches = 'tight', dpi=300)"
530
+ ]
531
+ },
532
+ {
533
+ "cell_type": "code",
534
+ "execution_count": null,
535
+ "metadata": {},
536
+ "outputs": [],
537
+ "source": [
538
+ "a.shape, b.shape, c.shape"
539
+ ]
540
+ },
541
+ {
542
+ "cell_type": "code",
543
+ "execution_count": null,
544
+ "metadata": {},
545
+ "outputs": [],
546
+ "source": []
547
+ }
548
+ ],
549
+ "metadata": {
550
+ "kernelspec": {
551
+ "display_name": "uobjl",
552
+ "language": "python",
553
+ "name": "python3"
554
+ },
555
+ "language_info": {
556
+ "codemirror_mode": {
557
+ "name": "ipython",
558
+ "version": 3
559
+ },
560
+ "file_extension": ".py",
561
+ "mimetype": "text/x-python",
562
+ "name": "python",
563
+ "nbconvert_exporter": "python",
564
+ "pygments_lexer": "ipython3",
565
+ "version": "3.8.18"
566
+ },
567
+ "orig_nbformat": 4
568
+ },
569
+ "nbformat": 4,
570
+ "nbformat_minor": 2
571
+ }
outputs/VOC_000030-peekaboo.png ADDED