osimeoni committed
Commit 25cae60
1 Parent(s): 907a760

FOUND - second
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
__init__.py ADDED
@@ -0,0 +1,3 @@
+ import sys
+ from os.path import dirname, join
+ sys.path.insert(0, join(dirname(__file__), '.'))
app.py CHANGED
@@ -1,16 +1,94 @@
1
  import gradio as gr
2
 
3
  title = 'FOUND'
4
  description = 'Gradio Demo accompanying paper "Unsupervised Object Localization: Observing the Background to Discover Objects"\n \
5
The app is running CPU-only, inference times are therefore longer.\n'
6
- article = """<h1 align="center">[FOUND] Unsupervised Object Localization: Observing the Background to Discover Objects</h1>
 
7
  """
 
 
8
 
9
- def greet(name):
10
- return "Hello " + name + "!!"
11
 
12
- iface = gr.Interface(fn=greet, title=title, description=description,
13
- article=article, inputs="text", outputs="text")
14
- iface.launch()
15
 
16
 
 
1
+ import os
2
+ import torch
3
+ import argparse
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import matplotlib.pyplot as plt
7
+
8
+ from PIL import Image
9
+ from model import FoundModel
10
+ from misc import load_config
11
+ from torchvision import transforms as T
12
+
13
+
14
  import gradio as gr
15
+
16
+ NORMALIZE = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
17
+ CACHE = True
18
+
19
+ def blend_images(bg, fg, alpha=0.5):
20
+ fg = fg.convert('RGBA')
21
+ bg = bg.convert('RGBA')
22
+ blended = Image.blend(bg, fg, alpha=alpha)
23
+
24
+ return blended
25
+
26
+
27
+ def predict(img_input):
28
+
29
+ config = "configs/found_DUTS-TR.yaml"
30
+ model_weights = "data/weights/decoder_weights.pt"
31
+
32
+ # Configuration
33
+ config = load_config(config)
34
+
35
+ # ------------------------------------
36
+ # Load the model
37
+ model = FoundModel(vit_model=config.model["pre_training"],
38
+ vit_arch=config.model["arch"],
39
+ vit_patch_size=config.model["patch_size"],
40
+ enc_type_feats=config.found["feats"],
41
+ bkg_type_feats=config.found["feats"],
42
+ bkg_th=config.found["bkg_th"])
43
+ # Load weights
44
+ model.decoder_load_weights(model_weights)
45
+ model.eval()
46
+ print(f"Model {model_weights} loaded correctly.")
47
+
48
+ # Load the image
49
+ img_pil = Image.open(img_input)
50
+ img = img_pil.convert("RGB")
51
+
52
+ t = T.Compose([T.ToTensor(), NORMALIZE])
53
+ img_t = t(img)[None,:,:,:]
54
+ inputs = img_t.to("cuda")
55
+
56
+ # Forward step
57
+ with torch.no_grad():
58
+ preds, _, _, _ = model.forward_step(inputs, for_eval=True)
59
+
60
+ # Apply FOUND
61
+ sigmoid = nn.Sigmoid()
62
+ h, w = img_t.shape[-2:]
63
+ preds_up = F.interpolate(
64
+ preds, scale_factor=model.vit_patch_size, mode="bilinear", align_corners=False
65
+ )[..., :h, :w]
66
+ preds_up = (
67
+ (sigmoid(preds_up.detach()) > 0.5).squeeze(0).float()
68
+ )
69
+
70
+ return blend_images(img_pil, preds_up)
71
+
72
 
73
  title = 'FOUND'
74
  description = 'Gradio Demo accompanying paper "Unsupervised Object Localization: Observing the Background to Discover Objects"\n \
75
The app is running CPU-only, inference times are therefore longer.\n'
76
+ article = """<h2 align="center">Unsupervised Object Localization: Observing the Background to Discover Objects </h2>
77
+ <h1 align="center"> FOUND </h1>
78
  """
79
+ examples = ["data/examples/VOC_000030.jpg"]
80
+
81
 
82
+ iface = gr.Interface(fn=predict,
83
+ title=title,
84
+ description=description,
85
+ article=article,
86
+ inputs=gr.Image(type='filepath'),
87
+ outputs=gr.Image(label="Object localization", type="pil"),
88
+ examples=examples,
89
+ cache_examples=CACHE
90
+ )
91
 
92
+ iface.launch(show_error=True, enable_queue=True, inline=True)
 
 
93
 
94
 
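For reference, a minimal standalone sketch of the same inference path as the new predict() function, kept on the CPU to match the "CPU-only" note in the description (app.py as committed still moves inputs to "cuda"). File paths are the ones added in this commit; model.FoundModel and misc.load_config are taken on trust from app.py since they are not part of this diff.

# Hypothetical smoke test, not part of the commit: mirrors predict() but keeps tensors on CPU.
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms as T
from model import FoundModel
from misc import load_config

config = load_config("configs/found_DUTS-TR.yaml")
model = FoundModel(vit_model=config.model["pre_training"],
                   vit_arch=config.model["arch"],
                   vit_patch_size=config.model["patch_size"],
                   enc_type_feats=config.found["feats"],
                   bkg_type_feats=config.found["feats"],
                   bkg_th=config.found["bkg_th"])
model.decoder_load_weights("data/weights/decoder_weights.pt")
model.eval()

img = Image.open("data/examples/VOC_000030.jpg").convert("RGB")
normalize = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
inputs = T.Compose([T.ToTensor(), normalize])(img)[None]          # no .to("cuda") here

with torch.no_grad():
    preds, _, _, _ = model.forward_step(inputs, for_eval=True)

h, w = inputs.shape[-2:]
preds_up = F.interpolate(preds, scale_factor=model.vit_patch_size,
                         mode="bilinear", align_corners=False)[..., :h, :w]
mask = (torch.sigmoid(preds_up) > 0.5).squeeze(0).float()         # binary object mask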
bilateral_solver.py ADDED
@@ -0,0 +1,214 @@
1
+ """
2
+ Code adapted from TokenCut: https://github.com/YangtaoWANG95/TokenCut
3
+ """
4
+
5
+ import PIL.Image as Image
6
+ import numpy as np
7
+ from scipy import ndimage
8
+ from scipy.sparse import diags, csr_matrix
9
+ from scipy.sparse.linalg import cg
10
+
11
+ RGB_TO_YUV = np.array(
12
+ [[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]]
13
+ )
14
+ YUV_TO_RGB = np.array([[1.0, 0.0, 1.402], [1.0, -0.34414, -0.71414], [1.0, 1.772, 0.0]])
15
+ YUV_OFFSET = np.array([0, 128.0, 128.0]).reshape(1, 1, -1)
16
+ MAX_VAL = 255.0
17
+
18
+
19
+ def rgb2yuv(im):
20
+ return np.tensordot(im, RGB_TO_YUV, ([2], [1])) + YUV_OFFSET
21
+
22
+
23
+ def yuv2rgb(im):
24
+ return np.tensordot(im.astype(float) - YUV_OFFSET, YUV_TO_RGB, ([2], [1]))
25
+
26
+
27
+ def get_valid_idx(valid, candidates):
28
+ """Find which values are present in a list and where they are located"""
29
+ locs = np.searchsorted(valid, candidates)
30
+ # Handle edge case where the candidate is larger than all valid values
31
+ locs = np.clip(locs, 0, len(valid) - 1)
32
+ # Identify which values are actually present
33
+ valid_idx = np.flatnonzero(valid[locs] == candidates)
34
+ locs = locs[valid_idx]
35
+ return valid_idx, locs
36
+
37
+
38
+ class BilateralGrid(object):
39
+ def __init__(self, im, sigma_spatial=32, sigma_luma=8, sigma_chroma=8):
40
+ im_yuv = rgb2yuv(im)
41
+ # Compute 5-dimensional XYLUV bilateral-space coordinates
42
+ Iy, Ix = np.mgrid[: im.shape[0], : im.shape[1]]
43
+ x_coords = (Ix / sigma_spatial).astype(int)
44
+ y_coords = (Iy / sigma_spatial).astype(int)
45
+ luma_coords = (im_yuv[..., 0] / sigma_luma).astype(int)
46
+ chroma_coords = (im_yuv[..., 1:] / sigma_chroma).astype(int)
47
+ coords = np.dstack((x_coords, y_coords, luma_coords, chroma_coords))
48
+ coords_flat = coords.reshape(-1, coords.shape[-1])
49
+ self.npixels, self.dim = coords_flat.shape
50
+ # Hacky "hash vector" for coordinates,
51
+ # Requires all scaled coordinates to be < MAX_VAL
52
+ self.hash_vec = MAX_VAL ** np.arange(self.dim)
53
+ # Construct S and B matrix
54
+ self._compute_factorization(coords_flat)
55
+
56
+ def _compute_factorization(self, coords_flat):
57
+ # Hash each coordinate in grid to a unique value
58
+ hashed_coords = self._hash_coords(coords_flat)
59
+ unique_hashes, unique_idx, idx = np.unique(
60
+ hashed_coords, return_index=True, return_inverse=True
61
+ )
62
+ # Identify unique set of vertices
63
+ unique_coords = coords_flat[unique_idx]
64
+ self.nvertices = len(unique_coords)
65
+ # Construct sparse splat matrix that maps from pixels to vertices
66
+ self.S = csr_matrix((np.ones(self.npixels), (idx, np.arange(self.npixels))))
67
+ # Construct sparse blur matrices.
68
+ # Note that these represent [1 0 1] blurs, excluding the central element
69
+ self.blurs = []
70
+ for d in range(self.dim):
71
+ blur = 0.0
72
+ for offset in (-1, 1):
73
+ offset_vec = np.zeros((1, self.dim))
74
+ offset_vec[:, d] = offset
75
+ neighbor_hash = self._hash_coords(unique_coords + offset_vec)
76
+ valid_coord, idx = get_valid_idx(unique_hashes, neighbor_hash)
77
+ blur = blur + csr_matrix(
78
+ (np.ones((len(valid_coord),)), (valid_coord, idx)),
79
+ shape=(self.nvertices, self.nvertices),
80
+ )
81
+ self.blurs.append(blur)
82
+
83
+ def _hash_coords(self, coord):
84
+ """Hacky function to turn a coordinate into a unique value"""
85
+ return np.dot(coord.reshape(-1, self.dim), self.hash_vec)
86
+
87
+ def splat(self, x):
88
+ return self.S.dot(x)
89
+
90
+ def slice(self, y):
91
+ return self.S.T.dot(y)
92
+
93
+ def blur(self, x):
94
+ """Blur a bilateral-space vector with a 1 2 1 kernel in each dimension"""
95
+ assert x.shape[0] == self.nvertices
96
+ out = 2 * self.dim * x
97
+ for blur in self.blurs:
98
+ out = out + blur.dot(x)
99
+ return out
100
+
101
+ def filter(self, x):
102
+ """Apply bilateral filter to an input x"""
103
+ return self.slice(self.blur(self.splat(x))) / self.slice(
104
+ self.blur(self.splat(np.ones_like(x)))
105
+ )
106
+
107
+
108
+ def bistochastize(grid, maxiter=10):
109
+ """Compute diagonal matrices to bistochastize a bilateral grid"""
110
+ m = grid.splat(np.ones(grid.npixels))
111
+ n = np.ones(grid.nvertices)
112
+ for i in range(maxiter):
113
+ n = np.sqrt(n * m / grid.blur(n))
114
+ # Correct m to satisfy the assumption of bistochastization regardless
115
+ # of how many iterations have been run.
116
+ m = n * grid.blur(n)
117
+ Dm = diags(m, 0)
118
+ Dn = diags(n, 0)
119
+ return Dn, Dm
120
+
121
+
122
+ class BilateralSolver(object):
123
+ def __init__(self, grid, params):
124
+ self.grid = grid
125
+ self.params = params
126
+ self.Dn, self.Dm = bistochastize(grid)
127
+
128
+ def solve(self, x, w):
129
+ # Check that w is a vector or a nx1 matrix
130
+ if w.ndim == 2:
131
+ assert w.shape[1] == 1
132
+ elif w.dim == 1:
133
+ w = w.reshape(w.shape[0], 1)
134
+ A_smooth = self.Dm - self.Dn.dot(self.grid.blur(self.Dn))
135
+ w_splat = self.grid.splat(w)
136
+ A_data = diags(w_splat[:, 0], 0)
137
+ A = self.params["lam"] * A_smooth + A_data
138
+ xw = x * w
139
+ b = self.grid.splat(xw)
140
+ # Use simple Jacobi preconditioner
141
+ A_diag = np.maximum(A.diagonal(), self.params["A_diag_min"])
142
+ M = diags(1 / A_diag, 0)
143
+ # Flat initialization
144
+ y0 = self.grid.splat(xw) / w_splat
145
+ yhat = np.empty_like(y0)
146
+ for d in range(x.shape[-1]):
147
+ yhat[..., d], info = cg(
148
+ A,
149
+ b[..., d],
150
+ x0=y0[..., d],
151
+ M=M,
152
+ maxiter=self.params["cg_maxiter"],
153
+ tol=self.params["cg_tol"],
154
+ )
155
+ xhat = self.grid.slice(yhat)
156
+ return xhat
157
+
158
+
159
+ def bilateral_solver_output(
160
+ img_pth,
161
+ target,
162
+ img=None,
163
+ sigma_spatial=24,
164
+ sigma_luma=4,
165
+ sigma_chroma=4,
166
+ get_all_cc=False
167
+ ):
168
+ if img is None:
169
+ reference = np.array(Image.open(img_pth).convert("RGB"))
170
+ else:
171
+ reference = np.array(img)
172
+
173
+ h, w = target.shape
174
+ confidence = np.ones((h, w)) * 0.999
175
+
176
+ grid_params = {
177
+ "sigma_luma": sigma_luma, # Brightness bandwidth
178
+ "sigma_chroma": sigma_chroma, # Color bandwidth
179
+ "sigma_spatial": sigma_spatial, # Spatial bandwidth
180
+ }
181
+
182
+ bs_params = {
183
+ "lam": 256, # The strength of the smoothness parameter
184
+ "A_diag_min": 1e-5, # Clamp the diagonal of the A diagonal in the Jacobi preconditioner.
185
+ "cg_tol": 1e-5, # The tolerance on the convergence in PCG
186
+ "cg_maxiter": 25, # The number of PCG iterations
187
+ }
188
+
189
+ grid = BilateralGrid(reference, **grid_params)
190
+
191
+ t = target.reshape(-1, 1).astype(np.double)
192
+ c = confidence.reshape(-1, 1).astype(np.double)
193
+
194
+ ## output solver, which is a soft value
195
+ output_solver = BilateralSolver(grid, bs_params).solve(t, c).reshape((h, w))
196
+
197
+ binary_solver = ndimage.binary_fill_holes(output_solver > 0.5)
198
+ labeled, nr_objects = ndimage.label(binary_solver)
199
+
200
+ nb_pixel = [np.sum(labeled == i) for i in range(nr_objects + 1)]
201
+ pixel_order = np.argsort(nb_pixel)
202
+
203
+ if get_all_cc:
204
+ # Remove known background
205
+ pixel_descending_order = pixel_order[::-1]
206
+ # Get all CCs except the biggest one, which may be considered background; adjust here if needed
207
+ binary_solver = (labeled[None,:,:] == pixel_descending_order[1:,None,None]).astype(int).sum(0)
208
+ else:
209
+ try:
210
+ binary_solver = labeled == pixel_order[-2]
211
+ except:
212
+ binary_solver = np.ones((h, w), dtype=bool)
213
+
214
+ return output_solver, binary_solver
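A short usage sketch of bilateral_solver_output (hypothetical, not part of the commit): it refines a coarse (H, W) mask against its source image and returns both a soft and a binarized result.

# Hypothetical sketch: refine a coarse mask with the bilateral solver above.
# The rectangle below is only a placeholder for a real coarse prediction.
import numpy as np
from PIL import Image
from bilateral_solver import bilateral_solver_output

img_path = "data/examples/VOC_000030.jpg"
w, h = Image.open(img_path).size
coarse = np.zeros((h, w), dtype=float)
coarse[h // 4: 3 * h // 4, w // 4: 3 * w // 4] = 1.0

soft, binary = bilateral_solver_output(img_path, coarse)
# soft:   edge-aware float refinement of `coarse`
# binary: the connected component kept as foreground (the largest region is treated as background)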
bkg_seg.py ADDED
@@ -0,0 +1,84 @@
1
+ # Copyright 2022 - Valeo Comfort and Driving Assistance - Oriane Siméoni @ valeo.ai
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+
18
+ from typing import Tuple
19
+
20
+ def compute_img_bkg_seg(
21
+ attentions,
22
+ feats,
23
+ featmap_dims,
24
+ th_bkg,
25
+ dim=64,
26
+ epsilon: float = 1e-10,
27
+ apply_weights: bool = True,
28
+ ) -> Tuple[torch.Tensor, float]:
29
+ """
30
+ inputs
31
+ - attentions [B, ]
32
+ """
33
+
34
+ w_featmap, h_featmap = featmap_dims
35
+
36
+ nb, nh, _ = attentions.shape[:3]
37
+ # we keep only the output patch attention
38
+ att = attentions[:, :, 0, 1:].reshape(nb, nh, -1)
39
+ att = att.reshape(nb, nh, w_featmap, h_featmap)
40
+
41
+ # -----------------------------------------------
42
+ # Inspired by the CroW sparsity-based channel weighting of each head (CroW, Kalantidis et al.)
43
+ threshold = torch.mean(att.reshape(nb, -1), dim=1) # Find threshold per image
44
+ Q = torch.sum(
45
+ att.reshape(nb, nh, w_featmap * h_featmap) > threshold[:, None, None], axis=2
46
+ ) / (w_featmap * h_featmap)
47
+ beta = torch.log(torch.sum(Q + epsilon, dim=1)[:, None] / (Q + epsilon))
48
+
49
+ # Weight features based on attention sparsity
50
+ descs = feats[:,1:,]
51
+ if apply_weights:
52
+ descs = (descs.reshape(nb, -1, nh, dim) * beta[:, None, :, None]).reshape(
53
+ nb, -1, nh * dim
54
+ )
55
+ else:
56
+ descs = (descs.reshape(nb, -1, nh, dim)).reshape(
57
+ nb, -1, nh * dim
58
+ )
59
+
60
+ # -----------------------------------------------
61
+ # Compute cosine-similarities
62
+ descs = F.normalize(descs, dim=-1, p=2)
63
+ cos_sim = torch.bmm(descs, descs.permute(0, 2, 1))
64
+
65
+ # -----------------------------------------------
66
+ # Find pixel with least amount of attention
67
+ if apply_weights:
68
+ att = att.reshape(nb, nh, w_featmap, h_featmap) * beta[:, :, None, None]
69
+ else:
70
+ att = att.reshape(nb, nh, w_featmap, h_featmap)
71
+ id_pixel_ref = torch.argmin(torch.sum(att, axis=1).reshape(nb, -1), dim=-1)
72
+
73
+ # -----------------------------------------------
74
+ # Mask of definitely background pixels: 1 on the background
75
+ cos_sim = cos_sim.reshape(nb, -1, w_featmap * h_featmap)
76
+
77
+ bkg_mask = (
78
+ cos_sim[torch.arange(cos_sim.size(0)), id_pixel_ref, :].reshape(
79
+ nb, w_featmap, h_featmap
80
+ )
81
+ > th_bkg
82
+ ) # mask to be used to remove background
83
+
84
+ return bkg_mask.float()
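As a shape reference, a hypothetical sketch of calling compute_img_bkg_seg on DINO ViT-S/8 style attentions and features for a 224x224 input; the shapes are inferred from the indexing above, not stated in this commit.

# Hypothetical shape sketch, not part of the commit.
import torch
from bkg_seg import compute_img_bkg_seg

B, heads, dim = 2, 6, 64          # ViT-S: 6 heads x 64 dims per head
wf = hf = 224 // 8                # feature-map side for patch size 8
N = wf * hf
attentions = torch.rand(B, heads, N + 1, N + 1)   # last-block self-attention (CLS + patch tokens)
feats = torch.rand(B, N + 1, heads * dim)         # patch features, CLS token first

bkg_mask = compute_img_bkg_seg(attentions, feats, (wf, hf), th_bkg=0.3, dim=dim)
print(bkg_mask.shape)             # torch.Size([2, 28, 28]); 1.0 marks background patches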
configs/found_DUTS-TR.yaml ADDED
@@ -0,0 +1,34 @@
+ model:
+   arch: vit_small
+   patch_size: 8
+   pre_training: dino
+
+ found:
+   bkg_th: 0.3
+   feats: k
+
+ training:
+   dataset: DUTS-TR
+   dataset_set: null
+
+   # Hyper params
+   seed: 0
+   max_iter: 500
+   nb_epochs: 3
+   batch_size: 50
+   lr0: 5e-2
+   step_lr_size: 50
+   step_lr_gamma: 0.95
+   w_bs_loss: 1.5
+   stop_bkg_loss: 100
+
+   # Augmentations
+   crop_size: 224
+   scale_range: [0.1, 3.0]
+   photometric_aug: gaussian_blur
+   proba_photometric_aug: 0.5
+   cropping_strategy: random_scale
+
+ evaluation:
+   type: saliency # uod, retrieval
+   datasets: [DUT-OMRON, ECSSD]
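In app.py the file above is read through misc.load_config, which is not part of this diff; as a sketch, the same keys can be inspected with plain PyYAML.

# Hypothetical sketch, not part of the commit.
import yaml

with open("configs/found_DUTS-TR.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["model"]["arch"], cfg["model"]["patch_size"])   # vit_small 8
print(cfg["found"]["bkg_th"], cfg["found"]["feats"])      # 0.3 k
print(cfg["training"]["crop_size"])                       # 224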
data/examples/VOC_000030.jpg ADDED
data/weights/decoder_weights.pt ADDED
Binary file (2.69 kB).
 
datasets/VOC.py ADDED
@@ -0,0 +1,80 @@
1
+ import os
2
+ from typing import Optional, Tuple, Union, Dict, List
3
+
4
+ import cv2
5
+ from pycocotools.coco import COCO
6
+ import numpy as np
7
+ import torch
8
+ import torchvision
9
+ from PIL import Image, PngImagePlugin
10
+ from torch.utils.data import Dataset
11
+ from torchvision import transforms as T
12
+ from torchvision.transforms import ColorJitter, RandomApply, RandomGrayscale
13
+ from tqdm import tqdm
14
+
15
+ VOCDetectionMetadataType = Dict[str, Dict[str, Union[str, Dict[str, str], List[str]]]]
16
+
17
+ def get_voc_detection_gt(
18
+ metadata: VOCDetectionMetadataType, remove_hards: bool = False
19
+ ) -> Tuple[np.array, List[str]]:
20
+ objects = metadata["annotation"]["object"]
21
+ nb_obj = len(objects)
22
+
23
+ gt_bbxs = []
24
+ gt_clss = []
25
+ for object in range(nb_obj):
26
+ if remove_hards and (
27
+ objects[object]["truncated"] == "1"
28
+ or objects[object]["difficult"] == "1"
29
+ ):
30
+ continue
31
+
32
+ gt_cls = objects[object]["name"]
33
+ gt_clss.append(gt_cls)
34
+ obj = objects[object]["bndbox"]
35
+ x1y1x2y2 = [
36
+ int(obj["xmin"]),
37
+ int(obj["ymin"]),
38
+ int(obj["xmax"]),
39
+ int(obj["ymax"]),
40
+ ]
41
+
42
+ # Original annotations are integers in the range [1, W or H]
43
+ # Assuming they mean 1-based pixel indices (inclusive),
44
+ # a box with annotation (xmin=1, xmax=W) covers the whole image.
45
+ # In coordinate space this is represented by (xmin=0, xmax=W)
46
+ x1y1x2y2[0] -= 1
47
+ x1y1x2y2[1] -= 1
48
+ gt_bbxs.append(x1y1x2y2)
49
+
50
+ return np.asarray(gt_bbxs), gt_clss
51
+
52
+ def create_gt_masks_if_voc(labels: PngImagePlugin.PngImageFile) -> Image.Image:
53
+ mask = np.array(labels)
54
+ mask_gt = (mask > 0).astype(float)
55
+ mask_gt = np.where(mask_gt != 0.0, 255, mask_gt)
56
+ mask_gt = Image.fromarray(np.uint8(mask_gt))
57
+ return mask_gt
58
+
59
+ def create_VOC_loader(img_dir, dataset_set, evaluation_type):
60
+ year = img_dir[-4:]
61
+ download = not os.path.exists(img_dir)
62
+ if evaluation_type == "uod":
63
+ loader = torchvision.datasets.VOCDetection(
64
+ img_dir,
65
+ year=year,
66
+ image_set=dataset_set,
67
+ transform=None,
68
+ download=download,
69
+ )
70
+ elif evaluation_type == "saliency":
71
+ loader = torchvision.datasets.VOCSegmentation(
72
+ img_dir,
73
+ year=year,
74
+ image_set=dataset_set,
75
+ transform=None,
76
+ download=download,
77
+ )
78
+ else:
79
+ raise ValueError(f"Not implemented for {evaluation_type}.")
80
+ return loader
datasets/__init__.py ADDED
File without changes
datasets/augmentations.py ADDED
@@ -0,0 +1,68 @@
1
+ """
2
+ Code borrowed from SelfMask: https://github.com/NoelShin/selfmask
3
+ """
4
+
5
+ import numpy as np
6
+ import torch
7
+ from PIL import Image
8
+ from typing import Optional, Tuple, Union
9
+ from torchvision.transforms import ColorJitter, RandomApply, RandomGrayscale
10
+
11
+ from datasets.utils import GaussianBlur
12
+ from datasets.geometric_transforms import (
13
+ random_scale,
14
+ random_crop,
15
+ random_hflip,
16
+ )
17
+
18
+ def geometric_augmentations(
19
+ image: Image.Image,
20
+ random_scale_range: Optional[Tuple[float, float]] = None,
21
+ random_crop_size: Optional[int] = None,
22
+ random_hflip_p: Optional[float] = None,
23
+ mask: Optional[Union[Image.Image, np.ndarray, torch.Tensor]] = None,
24
+ ignore_index: Optional[int] = None,
25
+ ) -> Tuple[Image.Image, torch.Tensor]:
26
+ """Note. image and mask are assumed to be of base size, thus share a spatial shape."""
27
+ if random_scale_range is not None:
28
+ image, mask = random_scale(
29
+ image=image, random_scale_range=random_scale_range, mask=mask
30
+ )
31
+
32
+ if random_crop_size is not None:
33
+ crop_size = (random_crop_size, random_crop_size)
34
+ fill = tuple(np.array(image).mean(axis=(0, 1)).astype(np.uint8).tolist())
35
+ image, offset = random_crop(image=image, crop_size=crop_size, fill=fill)
36
+
37
+ if mask is not None:
38
+ assert ignore_index is not None
39
+ mask = random_crop(
40
+ image=mask, crop_size=crop_size, fill=ignore_index, offset=offset
41
+ )[0]
42
+
43
+ if random_hflip_p is not None:
44
+ image, mask = random_hflip(image=image, p=random_hflip_p, mask=mask)
45
+ return image, mask
46
+
47
+ def photometric_augmentations(
48
+ image: Image.Image,
49
+ random_color_jitter: bool,
50
+ random_grayscale: bool,
51
+ random_gaussian_blur: bool,
52
+ proba_photometric_aug: float,
53
+ ) -> torch.Tensor:
54
+ if random_color_jitter:
55
+ color_jitter = ColorJitter(
56
+ brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2
57
+ )
58
+ image = RandomApply([color_jitter], p=proba_photometric_aug)(image)
59
+
60
+ if random_grayscale:
61
+ image = RandomGrayscale(proba_photometric_aug)(image)
62
+
63
+ if random_gaussian_blur:
64
+ w, h = image.size
65
+ image = GaussianBlur(kernel_size=int((0.1 * min(w, h) // 2 * 2) + 1))(
66
+ image, proba_photometric_aug
67
+ )
68
+ return image
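A hypothetical sketch of chaining these helpers the way the DUTS-TR config uses them (random_scale in [0.1, 3.0], Gaussian blur with probability 0.5, crop size 224). datasets.utils.GaussianBlur is assumed to exist elsewhere in the repository, since it is imported above but not part of this diff.

# Hypothetical sketch, not part of the commit.
import numpy as np
from PIL import Image
from datasets.augmentations import geometric_augmentations, photometric_augmentations
from datasets.geometric_transforms import resize

image = Image.open("data/examples/VOC_000030.jpg").convert("RGB")
mask = Image.fromarray(np.zeros(image.size[::-1], dtype=np.uint8))   # placeholder mask

image, mask = geometric_augmentations(image, random_scale_range=(0.1, 3.0), mask=mask)
image = resize(image, size=224, interpolation="bilinear")            # back to a fixed size
image = photometric_augmentations(image,
                                  random_color_jitter=False,
                                  random_grayscale=False,
                                  random_gaussian_blur=True,
                                  proba_photometric_aug=0.5)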
datasets/datasets.py ADDED
@@ -0,0 +1,409 @@
1
+ """
2
+ Dataset functions for applying Normalized Cut.
3
+ Code adapted from SelfMask: https://github.com/NoelShin/selfmask
4
+ """
5
+
6
+ import os
7
+ from typing import Optional, Tuple, Union
8
+
9
+ from pycocotools.coco import COCO
10
+ import numpy as np
11
+ import torch
12
+ import torchvision
13
+ from PIL import Image
14
+ from torch.utils.data import Dataset
15
+ from torchvision import transforms as T
16
+
17
+ from datasets.utils import unnormalize
18
+ from datasets.geometric_transforms import resize
19
+ from datasets.VOC import get_voc_detection_gt, create_gt_masks_if_voc, create_VOC_loader
20
+ from datasets.augmentations import geometric_augmentations, photometric_augmentations
21
+
22
+ from datasets.uod_datasets import UODDataset
23
+
24
+ NORMALIZE = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
25
+
26
+ def set_dataset_dir(dataset_name, root_dir):
27
+ if dataset_name == "ECSSD":
28
+ dataset_dir = os.path.join(root_dir, "ECSSD")
29
+ img_dir = os.path.join(dataset_dir, "images")
30
+ gt_dir = os.path.join(dataset_dir, "ground_truth_mask")
31
+
32
+ elif dataset_name == "DUTS-TEST":
33
+ dataset_dir = os.path.join(root_dir, "DUTS")
34
+ img_dir = os.path.join(dataset_dir, "DUTS-TE-Image")
35
+ gt_dir = os.path.join(dataset_dir, "DUTS-TE-Mask")
36
+
37
+ elif dataset_name == "DUTS-TR":
38
+ dataset_dir = os.path.join(root_dir, "DUTS")
39
+ img_dir = os.path.join(dataset_dir, "DUTS-TR-Image")
40
+ gt_dir = os.path.join(dataset_dir, "DUTS-TR-Mask")
41
+
42
+ elif dataset_name == "DUT-OMRON":
43
+ dataset_dir = os.path.join(root_dir, "DUT-OMRON")
44
+ img_dir = os.path.join(dataset_dir, "DUT-OMRON-image")
45
+ gt_dir = os.path.join(dataset_dir, "pixelwiseGT-new-PNG")
46
+
47
+ elif dataset_name == "VOC07":
48
+ dataset_dir = os.path.join(root_dir, "VOC2007")
49
+ img_dir = dataset_dir
50
+ gt_dir = dataset_dir
51
+
52
+ elif dataset_name == "VOC12":
53
+ dataset_dir = os.path.join('/datasets_local/osimeoni', "VOC2012")
54
+ img_dir = dataset_dir
55
+ gt_dir = dataset_dir
56
+
57
+ elif dataset_name == "COCO17":
58
+ dataset_dir = os.path.join(root_dir, "COCO")
59
+ img_dir = dataset_dir
60
+ gt_dir = dataset_dir
61
+
62
+ elif dataset_name == "ImageNet":
63
+ dataset_dir = os.path.join(root_dir, "ImageNet")
64
+ img_dir = dataset_dir
65
+ gt_dir = dataset_dir
66
+
67
+ else:
68
+ raise ValueError(f"Unknown dataset {dataset_name}")
69
+
70
+ return img_dir, gt_dir
71
+
72
+
73
+ def build_dataset(
74
+ root_dir: str,
75
+ dataset_name: str,
76
+ dataset_set: Optional[str] = None,
77
+ for_eval: bool = False,
78
+ config=None,
79
+ evaluation_type="saliency", # uod,
80
+ ):
81
+ """
82
+ Build dataset
83
+ """
84
+
85
+ if evaluation_type == "saliency":
86
+ img_dir, gt_dir = set_dataset_dir(dataset_name, root_dir)
87
+
88
+ dataset = FoundDataset(
89
+ name=dataset_name,
90
+ img_dir=img_dir,
91
+ gt_dir=gt_dir,
92
+ dataset_set=dataset_set,
93
+ config=config,
94
+ for_eval=for_eval,
95
+ evaluation_type=evaluation_type,
96
+ )
97
+
98
+ elif evaluation_type == "uod":
99
+ assert dataset_name in ["VOC07", "VOC12", "COCO20k"]
100
+ dataset_set = "trainval" if dataset_name in ["VOC07", "VOC12"] else "train"
101
+ no_hards = False
102
+ dataset = UODDataset(
103
+ dataset_name,
104
+ dataset_set,
105
+ root_dir=root_dir,
106
+ remove_hards=no_hards,
107
+ )
108
+
109
+ return dataset
110
+
111
+
112
+ class FoundDataset(Dataset):
113
+ def __init__(
114
+ self,
115
+ name: str,
116
+ img_dir: str,
117
+ gt_dir: str,
118
+ dataset_set: Optional[str] = None,
119
+ config=None,
120
+ for_eval:bool = False,
121
+ evaluation_type:str = "saliency",
122
+ ) -> None:
123
+ """
124
+ Args:
125
+ img_dir (string): Directory with all the images.
+ gt_dir (string): Directory with the ground-truth masks.
128
+ """
129
+ self.for_eval = for_eval
130
+ self.use_aug = not for_eval
131
+ self.evaluation_type = evaluation_type
132
+
133
+ assert evaluation_type in ["saliency"]
134
+
135
+ self.name = name
136
+ self.dataset_set = dataset_set
137
+ self.img_dir = img_dir
138
+ self.gt_dir = gt_dir
139
+
140
+ # if VOC dataset
141
+ self.loader = None
142
+ self.cocoGt = None
143
+
144
+ self.config = config
145
+
146
+ if "VOC" in self.name:
147
+ self.loader = create_VOC_loader(self.img_dir, dataset_set, evaluation_type)
148
+
149
+ # if ImageNet dataset
150
+ elif "ImageNet" in self.name:
151
+ self.loader = torchvision.datasets.ImageNet(
152
+ self.img_dir,
153
+ split=dataset_set,
154
+ transform=None,
155
+ target_transform=None,
156
+ )
157
+
158
+ elif "COCO" in self.name:
159
+ year = int("20"+self.name[-2:])
160
+ annFile=f'/datasets_local/COCO/annotations/instances_{dataset_set}{str(year)}.json'
161
+ self.cocoGt=COCO(annFile)
162
+ self.img_ids = list(sorted(self.cocoGt.getImgIds()))
163
+ self.img_dir = f'/datasets_local/COCO/images/{dataset_set}{str(year)}/'
164
+
165
+ # Transformations
166
+ if self.for_eval:
167
+ full_img_transform, no_norm_full_img_transform = self.get_init_transformation(
168
+ isVOC="VOC" in name
169
+ )
170
+ self.full_img_transform = full_img_transform
171
+ self.no_norm_full_img_transform = no_norm_full_img_transform
172
+
173
+ # Images
174
+ self.list_images = None
175
+ if not "VOC" in self.name and not "COCO" in self.name:
176
+ self.list_images = [
177
+ os.path.join(img_dir, i) for i in sorted(os.listdir(img_dir))
178
+ ]
179
+
180
+ self.ignore_index = -1
181
+ self.mean = NORMALIZE.mean
182
+ self.std = NORMALIZE.std
183
+ self.to_tensor_and_normalize = T.Compose([T.ToTensor(), NORMALIZE])
184
+ self.normalize = NORMALIZE
185
+
186
+ if config is not None and self.use_aug:
187
+ self._set_aug(config)
188
+
189
+
190
+ def get_init_transformation(self, isVOC: bool = False):
191
+ if isVOC:
192
+ t = T.Compose([T.PILToTensor(), T.ConvertImageDtype(torch.float), NORMALIZE])
193
+ t_nonorm = T.Compose([T.PILToTensor(), T.ConvertImageDtype(torch.float)])
194
+ return t, t_nonorm
195
+
196
+ else:
197
+ t = T.Compose([T.ToTensor(), NORMALIZE])
198
+ t_nonorm = T.Compose([T.ToTensor()])
199
+ return t, t_nonorm
200
+
201
+ def _set_aug(self, config):
202
+ """
203
+ Set augmentation based on config.
204
+ """
205
+
206
+ photometric_aug = config.training["photometric_aug"]
207
+
208
+ self.cropping_strategy = config.training["cropping_strategy"]
209
+ if self.cropping_strategy == "center_crop":
210
+ self.use_aug = False # default strategy, not considered to be a data aug
211
+ self.scale_range = config.training["scale_range"]
212
+ self.crop_size = config.training["crop_size"]
213
+ self.center_crop_transforms = T.Compose(
214
+ [
215
+ T.CenterCrop((self.crop_size, self.crop_size)),
216
+ T.ToTensor(),
217
+ ]
218
+ )
219
+ self.center_crop_only_transforms = T.Compose(
220
+ [T.CenterCrop((self.crop_size, self.crop_size)), T.PILToTensor()]
221
+ )
222
+
223
+ self.proba_photometric_aug = config.training["proba_photometric_aug"]
224
+
225
+ self.random_color_jitter = False
226
+ self.random_grayscale = False
227
+ self.random_gaussian_blur = False
228
+ if photometric_aug == "color_jitter":
229
+ self.random_color_jitter = True
230
+ elif photometric_aug == "grayscale":
231
+ self.random_grayscale = True
232
+ elif photometric_aug == "gaussian_blur":
233
+ self.random_gaussian_blur = True
234
+
235
+ def _preprocess_data_aug(
236
+ self,
237
+ image: Image.Image,
238
+ mask: Image.Image,
239
+ ignore_index: Optional[int] = None,
240
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
241
+ """Prepare data in a proper form for either training (data augmentation) or validation."""
242
+
243
+ # resize to base size
244
+ image = resize(
245
+ image,
246
+ size=self.crop_size,
247
+ edge="shorter",
248
+ interpolation="bilinear",
249
+ )
250
+ mask = resize(
251
+ mask,
252
+ size=self.crop_size,
253
+ edge="shorter",
254
+ interpolation="bilinear",
255
+ )
256
+
257
+ if not isinstance(mask, torch.Tensor):
258
+ mask: torch.Tensor = torch.tensor(np.array(mask))
259
+
260
+ random_scale_range = None
261
+ random_crop_size = None
262
+ random_hflip_p = None
263
+ if self.cropping_strategy == "random_scale":
264
+ random_scale_range = self.scale_range
265
+ elif self.cropping_strategy == "random_crop":
266
+ random_crop_size = self.crop_size
267
+ elif self.cropping_strategy == "random_hflip":
268
+ random_hflip_p = 0.5
269
+ elif self.cropping_strategy == "random_crop_and_hflip":
270
+ random_hflip_p = 0.5
271
+ random_crop_size = self.crop_size
272
+
273
+ if random_crop_size or random_hflip_p or random_scale_range:
274
+ image, mask = geometric_augmentations(
275
+ image=image,
276
+ mask=mask,
277
+ random_scale_range=random_scale_range,
278
+ random_crop_size=random_crop_size,
279
+ ignore_index=ignore_index,
280
+ random_hflip_p=random_hflip_p,
281
+ )
282
+
283
+ if random_scale_range:
284
+ # resize to (self.crop_size, self.crop_size)
285
+ image = resize(
286
+ image,
287
+ size=self.crop_size,
288
+ interpolation="bilinear",
289
+ )
290
+ mask = resize(
291
+ mask,
292
+ size=(self.crop_size, self.crop_size),
293
+ interpolation="bilinear",
294
+ )
295
+
296
+ image = photometric_augmentations(
297
+ image,
298
+ random_color_jitter=self.random_color_jitter,
299
+ random_grayscale=self.random_grayscale,
300
+ random_gaussian_blur=self.random_gaussian_blur,
301
+ proba_photometric_aug=self.proba_photometric_aug,
302
+ )
303
+
304
+ # to tensor + normalize image
305
+ image = self.to_tensor_and_normalize(image)
306
+
307
+ return image, mask
308
+
309
+ def __len__(self) -> int:
310
+ if "VOC" in self.name:
311
+ return len(self.loader)
312
+ elif "ImageNet" in self.name:
313
+ return len(self.loader)
314
+ elif "COCO" in self.name:
315
+ return len(self.img_ids)
316
+ return len(self.list_images)
317
+
318
+ def _apply_center_crop(
319
+ self, image: Image.Image, mask: Union[Image.Image, np.ndarray, torch.Tensor]
320
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
321
+ img_t = self.center_crop_transforms(image)
322
+ # need to normalize image
323
+ img_t = self.normalize(img_t)
324
+ mask_gt = self.center_crop_transforms(mask).squeeze()
325
+ return img_t, mask_gt
326
+
327
+
328
+ def __getitem__(self, idx, get_mask_gt=True):
329
+ if "VOC" in self.name:
330
+ img, gt_labels = self.loader[idx]
331
+ if self.evaluation_type == "uod":
332
+ gt_labels, _ = get_voc_detection_gt(
333
+ gt_labels, remove_hards=False
334
+ )
335
+ elif self.evaluation_type == "saliency":
336
+ mask_gt = create_gt_masks_if_voc(gt_labels)
337
+ img_path = self.loader.images[idx]
338
+
339
+ elif "ImageNet" in self.name:
340
+ img, _ = self.loader[idx]
341
+ img_path = self.loader.imgs[idx][0]
342
+ # empty mask since no gt mask, only class label
343
+ zeros = np.zeros(np.array(img).shape[:2])
344
+ mask_gt = Image.fromarray(zeros)
345
+
346
+ elif "COCO" in self.name:
347
+ img_id = self.img_ids[idx]
348
+
349
+ path = self.cocoGt.loadImgs(img_id)[0]["file_name"]
350
+ img = Image.open(os.path.join(self.img_dir, path)).convert("RGB")
351
+ _ = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(img_id))
352
+ img_path = self.img_ids[idx] # What matters most is the id for eval
353
+
354
+ # empty mask since no gt mask, only class label
355
+ zeros = np.zeros(np.array(img).shape[:2])
356
+ mask_gt = Image.fromarray(zeros)
357
+
358
+ # For all others
359
+ else:
360
+ img_path = self.list_images[idx]
361
+ with open(img_path, "rb") as f:
362
+ img = Image.open(f)
363
+ img = img.convert("RGB")
364
+ im_name = img_path.split("/")[-1]
365
+ mask_gt = Image.open(
366
+ os.path.join(self.gt_dir, im_name.replace(".jpg", ".png"))
367
+ ).convert("L")
368
+
369
+ if self.for_eval:
370
+ img_t = self.full_img_transform(img)
371
+ img_init = self.no_norm_full_img_transform(img)
372
+
373
+ if self.evaluation_type == "saliency":
374
+ mask_gt = torch.tensor(np.array(mask_gt)).squeeze()
375
+ mask_gt = np.array(mask_gt)
376
+ mask_gt = mask_gt == 255
377
+ mask_gt = torch.tensor(mask_gt)
378
+ else:
379
+ if self.use_aug:
380
+ img_t, mask_gt = self._preprocess_data_aug(
381
+ image=img, mask=mask_gt, ignore_index=self.ignore_index
382
+ )
383
+ mask_gt = np.array(mask_gt)
384
+ mask_gt = mask_gt == 255
385
+ mask_gt = torch.tensor(mask_gt)
386
+ else:
387
+ # no data aug
388
+ img_t, mask_gt = self._apply_center_crop(image=img, mask=mask_gt)
389
+ gt_labels = self.center_crop_only_transforms(gt_labels).squeeze()
390
+ mask_gt = np.asarray(mask_gt, np.int64)
391
+ mask_gt = mask_gt == 1
392
+ mask_gt = torch.tensor(mask_gt)
393
+
394
+ img_init = unnormalize(img_t)
395
+
396
+ if not get_mask_gt:
397
+ mask_gt = None
398
+
399
+ if self.evaluation_type == "uod":
400
+ gt_labels = torch.tensor(gt_labels)
401
+ mask_gt = gt_labels
402
+
403
+ return img_t, img_init, mask_gt, img_path
404
+
405
+ def fullimg_mode(self):
406
+ self.val_full_image = True
407
+
408
+ def training_mode(self):
409
+ self.val_full_image = False
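A hypothetical sketch of building the DUTS-TR training split this module is written for; the root_dir value and the on-disk layout (<root_dir>/DUTS/DUTS-TR-Image, ...) are assumptions matching set_dataset_dir above, and misc.load_config is taken from app.py's imports since it is not part of this diff.

# Hypothetical sketch, not part of the commit.
from misc import load_config
from datasets.datasets import build_dataset

config = load_config("configs/found_DUTS-TR.yaml")
train_set = build_dataset(root_dir="/path/to/datasets",   # assumption: see set_dataset_dir
                          dataset_name="DUTS-TR",
                          dataset_set=None,
                          for_eval=False,
                          config=config,
                          evaluation_type="saliency")

img_t, img_init, mask_gt, img_path = train_set[0]
# img_t: normalized, augmented crop; img_init: the same crop unnormalized;
# mask_gt: boolean GT mask; img_path: path of the source image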
datasets/geometric_transforms.py ADDED
@@ -0,0 +1,160 @@
1
+ """
2
+ Code adapted from SelfMask: https://github.com/NoelShin/selfmask
3
+ """
4
+
5
+ from random import randint, random, uniform
6
+ from typing import Optional, Tuple, Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torchvision.transforms.functional as TF
11
+ from PIL import Image
12
+ from torchvision.transforms.functional import InterpolationMode as IM
13
+
14
+
15
+ def random_crop(
16
+ image: Union[Image.Image, np.ndarray, torch.Tensor],
17
+ crop_size: Tuple[int, int], # (h, w)
18
+ fill: Union[int, Tuple[int, int, int]], # an unsigned integer or RGB,
19
+ offset: Optional[Tuple[int, int]] = None, # (top, left) coordinate of a crop
20
+ ):
21
+ assert type(crop_size) in (tuple, list) and len(crop_size) == 2
22
+
23
+ if isinstance(image, np.ndarray):
24
+ image = torch.tensor(image)
25
+ h, w = image.shape[-2:]
26
+ elif isinstance(image, Image.Image):
27
+ w, h = image.size
28
+ elif isinstance(image, torch.Tensor):
29
+ h, w = image.shape[-2:]
30
+ else:
31
+ raise TypeError(type(image))
32
+
33
+ pad_h, pad_w = max(crop_size[0] - h, 0), max(crop_size[1] - w, 0)
34
+
35
+ image = TF.pad(image, [0, 0, pad_w, pad_h], fill=fill, padding_mode="constant")
36
+
37
+ if isinstance(image, Image.Image):
38
+ w, h = image.size
39
+ else:
40
+ h, w = image.shape[-2:]
41
+
42
+ if offset is None:
43
+ offset = (randint(0, h - crop_size[0]), randint(0, w - crop_size[1]))
44
+
45
+ image = TF.crop(
46
+ image, top=offset[0], left=offset[1], height=crop_size[0], width=crop_size[1]
47
+ )
48
+ return image, offset
49
+
50
+
51
+ def compute_size(
52
+ input_size: Tuple[int, int], output_size: int, edge: str # h, w
53
+ ) -> Tuple[int, int]:
54
+ assert edge in ["shorter", "longer"]
55
+ h, w = input_size
56
+
57
+ if edge == "longer":
58
+ if w > h:
59
+ h = int(float(h) / w * output_size)
60
+ w = output_size
61
+ else:
62
+ w = int(float(w) / h * output_size)
63
+ h = output_size
64
+ assert w <= output_size and h <= output_size
65
+
66
+ else:
67
+ if w > h:
68
+ w = int(float(w) / h * output_size)
69
+ h = output_size
70
+ else:
71
+ h = int(float(h) / w * output_size)
72
+ w = output_size
73
+ assert w >= output_size and h >= output_size
74
+ return h, w
75
+
76
+
77
+ def resize(
78
+ image: Union[Image.Image, np.ndarray, torch.Tensor],
79
+ size: Union[int, Tuple[int, int]],
80
+ interpolation: str,
81
+ edge: str = "both",
82
+ ) -> Union[Image.Image, torch.Tensor]:
83
+ """
84
+ :param image: an image to be resized
85
+ :param size: a resulting image size
86
+ :param interpolation: sampling mode. ["nearest", "bilinear", "bicubic"]
87
+ :param edge: Default: "both"
88
+ No-op if a size is given as a tuple (h, w).
89
+ If set to "both", resize both height and width to the specified size.
90
+ If set to "shorter", resize the shorter edge to the specified size keeping the aspect ratio.
91
+ If set to "longer", resize the longer edge to the specified size keeping the aspect ratio.
92
+ :return: a resized image
93
+ """
94
+ assert interpolation in ["nearest", "bilinear", "bicubic"], ValueError(
95
+ interpolation
96
+ )
97
+ assert edge in ["both", "shorter", "longer"], ValueError(edge)
98
+ interpolation = {
99
+ "nearest": IM.NEAREST,
100
+ "bilinear": IM.BILINEAR,
101
+ "bicubic": IM.BICUBIC,
102
+ }[interpolation]
103
+
104
+ if type(image) == torch.Tensor:
105
+ image = image.clone().detach()
106
+ elif type(image) == np.ndarray:
107
+ image = torch.from_numpy(image)
108
+
109
+ if type(size) is tuple:
110
+ if type(image) == torch.Tensor and len(image.shape) == 2:
111
+ image = TF.resize(
112
+ image.unsqueeze(dim=0), size=size, interpolation=interpolation
113
+ ).squeeze(dim=0)
114
+ else:
115
+ image = TF.resize(image, size=size, interpolation=interpolation)
116
+
117
+ else:
118
+ if edge == "both":
119
+ image = TF.resize(image, size=[size, size], interpolation=interpolation)
120
+
121
+ else:
122
+ if isinstance(image, Image.Image):
123
+ w, h = image.size
124
+ else:
125
+ h, w = image.shape[-2:]
126
+ rh, rw = compute_size(input_size=(h, w), output_size=size, edge=edge)
127
+ image = TF.resize(image, size=[rh, rw], interpolation=interpolation)
128
+ return image
129
+
130
+
131
+ def random_scale(
132
+ image: Union[Image.Image, np.ndarray, torch.Tensor],
133
+ random_scale_range: Tuple[float, float],
134
+ mask: Optional[Union[Image.Image, np.ndarray, torch.Tensor]] = None,
135
+ ):
136
+ scale = uniform(*random_scale_range)
137
+ if isinstance(image, Image.Image):
138
+ w, h = image.size
139
+ else:
140
+ h, w = image.shape[-2:]
141
+ w_rs, h_rs = int(w * scale), int(h * scale)
142
+ image: Image.Image = resize(image, size=(h_rs, w_rs), interpolation="bilinear")
143
+ if mask is not None:
144
+ mask = resize(mask, size=(h_rs, w_rs), interpolation="nearest")
145
+ return image, mask
146
+
147
+
148
+ def random_hflip(
149
+ image: Union[Image.Image, np.ndarray, torch.Tensor],
150
+ p: float,
151
+ mask: Optional[Union[np.ndarray, torch.Tensor]] = None,
152
+ ):
153
+ assert 0.0 <= p <= 1.0, ValueError(random_hflip)
154
+
155
+ # Return a random floating point number in the range [0.0, 1.0).
156
+ if random() > p:
157
+ image = TF.hflip(image)
158
+ if mask is not None:
159
+ mask = TF.hflip(mask)
160
+ return image, mask
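A hypothetical sketch of the shorter-edge resize plus mean-padded random crop these helpers implement; the fill computation mirrors datasets/augmentations.py above.

# Hypothetical sketch, not part of the commit.
import numpy as np
from PIL import Image
from datasets.geometric_transforms import resize, random_crop

image = Image.open("data/examples/VOC_000030.jpg").convert("RGB")
image = resize(image, size=224, interpolation="bilinear", edge="shorter")

fill = tuple(np.array(image).mean(axis=(0, 1)).astype(np.uint8).tolist())
crop, offset = random_crop(image, crop_size=(224, 224), fill=fill)
print(crop.size, offset)   # (224, 224) and the (top, left) corner used for the crop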
datasets/uod_datasets.py ADDED
@@ -0,0 +1,384 @@
1
+ # Copyright 2021 - Valeo Comfort and Driving Assistance - Oriane Siméoni @ valeo.ai
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Code adapted from previous method LOST: https://github.com/valeoai/LOST
17
+ """
18
+
19
+ import os
20
+ import math
21
+ import torch
22
+ import json
23
+ import torchvision
24
+ import numpy as np
25
+ import skimage.io
26
+
27
+ from PIL import Image
28
+ from tqdm import tqdm
29
+ from torchvision import transforms as pth_transforms
30
+
31
+ # Image transformation applied to all images
32
+ transform = pth_transforms.Compose(
33
+ [
34
+ pth_transforms.ToTensor(),
35
+ pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
36
+ ]
37
+ )
38
+
39
+ class ImageDataset:
40
+ def __init__(
41
+ self,
42
+ image_path
43
+ ):
44
+
45
+ self.image_path = image_path
46
+ self.name = image_path.split("/")[-1]
47
+
48
+ # Read the image
49
+ with open(image_path, "rb") as f:
50
+ img = Image.open(f)
51
+ img = img.convert("RGB")
52
+
53
+ # Build a dataloader
54
+ img = transform(img)
55
+ self.dataloader = [[img, image_path]]
56
+
57
+ def get_image_name(self, *args, **kwargs):
58
+ return self.image_path.split("/")[-1].split(".")[0]
59
+
60
+ def load_image(self, *args, **kwargs):
61
+ return skimage.io.imread(self.image_path)
62
+
63
+ class UODDataset:
64
+ def __init__(
65
+ self,
66
+ dataset_name,
67
+ dataset_set,
68
+ root_dir,
69
+ remove_hards:bool = False,
70
+ ):
71
+ """
72
+ Build the dataloader
73
+ """
74
+
75
+ self.dataset_name = dataset_name
76
+ self.set = dataset_set
77
+ self.root_dir = root_dir
78
+
79
+ if dataset_name == "VOC07":
80
+ self.root_path = f"{root_dir}/VOC2007"
81
+ self.year = "2007"
82
+ elif dataset_name == "VOC12":
83
+ self.root_path = f"{root_dir}/VOC2012"
84
+ self.year = "2012"
85
+ elif dataset_name == "COCO20k":
86
+ self.year = "2014"
87
+ self.root_path = f"{root_dir}/COCO/images/{dataset_set}{self.year}"
88
+ self.sel20k = 'data/coco_20k_filenames.txt'
89
+ # JSON file constructed based on COCO train2014 gt
90
+ self.all_annfile = f"{root_dir}/COCO/annotations/instances_train2014.json"
91
+ self.annfile = f"{root_dir}/instances_train2014_sel20k.json"
92
+ if not os.path.exists(self.annfile):
93
+ select_coco_20k(self.sel20k, self.all_annfile)
94
+ else:
95
+ raise ValueError("Unknown dataset.")
96
+
97
+ if not os.path.exists(self.root_path):
98
+ raise ValueError("Please follow the README to setup the datasets.")
99
+
100
+ self.name = f"{self.dataset_name}_{self.set}"
101
+
102
+ # Build the dataloader
103
+ if "VOC" in dataset_name:
104
+ self.dataloader = torchvision.datasets.VOCDetection(
105
+ self.root_path,
106
+ year=self.year,
107
+ image_set=self.set,
108
+ transform=transform,
109
+ download=False,
110
+ )
111
+ elif "COCO20k" == dataset_name:
112
+ self.dataloader = torchvision.datasets.CocoDetection(
113
+ self.root_path, annFile=self.annfile, transform=transform
114
+ )
115
+ else:
116
+ raise ValueError("Unknown dataset.")
117
+
118
+ # Set hards images that are not included
119
+ self.remove_hards = remove_hards
120
+ self.hards = []
121
+ if remove_hards:
122
+ self.name += f"-nohards"
123
+ self.hards = self.get_hards()
124
+ print(f"Nb images discarded {len(self.hards)}")
125
+
126
+ def __len__(self) -> int:
127
+ return len(self.dataloader)
128
+
129
+ def load_image(self, im_name):
130
+ """
131
+ Load the image corresponding to the im_name
132
+ """
133
+ if "VOC" in self.dataset_name:
134
+ image = skimage.io.imread(f"{self.root_dir}/VOC{self.year}/JPEGImages/{im_name}")
135
+ elif "COCO" in self.dataset_name:
136
+ im_path = self.path_20k[self.sel_20k.index(im_name)]
137
+ image = skimage.io.imread(f"{self.root_dir}/COCO/images/{im_path}")
138
+ else:
139
+ raise ValueError("Unkown dataset.")
140
+ return image
141
+
142
+ def get_image_name(self, inp):
143
+ """
144
+ Return the image name
145
+ """
146
+ if "VOC" in self.dataset_name:
147
+ im_name = inp["annotation"]["filename"]
148
+ elif "COCO" in self.dataset_name:
149
+ im_name = str(inp[0]["image_id"])
150
+
151
+ return im_name
152
+
153
+ def extract_gt(self, targets, im_name):
154
+ if "VOC" in self.dataset_name:
155
+ return extract_gt_VOC(targets, remove_hards=self.remove_hards)
156
+ elif "COCO" in self.dataset_name:
157
+ return extract_gt_COCO(targets, remove_iscrowd=True)
158
+ else:
159
+ raise ValueError("Unknown dataset")
160
+
161
+ def extract_classes(self):
162
+ if "VOC" in self.dataset_name:
163
+ cls_path = f"classes_{self.set}_{self.year}.txt"
164
+ elif "COCO" in self.dataset_name:
165
+ cls_path = f"classes_{self.dataset}_{self.set}_{self.year}.txt"
166
+
167
+ # Load if exists
168
+ if os.path.exists(cls_path):
169
+ all_classes = []
170
+ with open(cls_path, "r") as f:
171
+ for line in f:
172
+ all_classes.append(line.strip())
173
+ else:
174
+ print("Extract all classes from the dataset")
175
+ if "VOC" in self.dataset_name:
176
+ all_classes = self.extract_classes_VOC()
177
+ elif "COCO" in self.dataset_name:
178
+ all_classes = self.extract_classes_COCO()
179
+
180
+ with open(cls_path, "w") as f:
181
+ for s in all_classes:
182
+ f.write(str(s) + "\n")
183
+
184
+ return all_classes
185
+
186
+ def extract_classes_VOC(self):
187
+ all_classes = []
188
+ for im_id, inp in enumerate(tqdm(self.dataloader)):
189
+ objects = inp[1]["annotation"]["object"]
190
+
191
+ for o in range(len(objects)):
192
+ if objects[o]["name"] not in all_classes:
193
+ all_classes.append(objects[o]["name"])
194
+
195
+ return all_classes
196
+
197
+ def extract_classes_COCO(self):
198
+ all_classes = []
199
+ for im_id, inp in enumerate(tqdm(self.dataloader)):
200
+ objects = inp[1]
201
+
202
+ for o in range(len(objects)):
203
+ if objects[o]["category_id"] not in all_classes:
204
+ all_classes.append(objects[o]["category_id"])
205
+
206
+ return all_classes
207
+
208
+ def get_hards(self):
209
+ hard_path = "datasets/hard_%s_%s_%s.txt" % (self.dataset_name, self.set, self.year)
210
+ if os.path.exists(hard_path):
211
+ hards = []
212
+ with open(hard_path, "r") as f:
213
+ for line in f:
214
+ hards.append(int(line.strip()))
215
+ else:
216
+ print("Discover hard images that should be discarded")
217
+
218
+ if "VOC" in self.dataset_name:
219
+ # set the hards
220
+ hards = discard_hard_voc(self.dataloader)
221
+
222
+ with open(hard_path, "w") as f:
223
+ for s in hards:
224
+ f.write(str(s) + "\n")
225
+
226
+ return hards
227
+
228
+
229
+ def discard_hard_voc(dataloader):
230
+ hards = []
231
+ for im_id, inp in enumerate(tqdm(dataloader)):
232
+ objects = inp[1]["annotation"]["object"]
233
+ nb_obj = len(objects)
234
+
235
+ hard = np.zeros(nb_obj)
236
+ for i, o in enumerate(range(nb_obj)):
237
+ hard[i] = (
238
+ 1
239
+ if (objects[o]["truncated"] == "1" or objects[o]["difficult"] == "1")
240
+ else 0
241
+ )
242
+
243
+ # all images with only truncated or difficult objects
244
+ if np.sum(hard) == nb_obj:
245
+ hards.append(im_id)
246
+ return hards
247
+
248
+
249
+ def extract_gt_COCO(targets, remove_iscrowd=True):
250
+ objects = targets
251
+ nb_obj = len(objects)
252
+
253
+ gt_bbxs = []
254
+ gt_clss = []
255
+ for o in range(nb_obj):
256
+ # Remove iscrowd boxes
257
+ if remove_iscrowd and objects[o]["iscrowd"] == 1:
258
+ continue
259
+ gt_cls = objects[o]["category_id"]
260
+ gt_clss.append(gt_cls)
261
+ bbx = objects[o]["bbox"]
262
+ x1y1x2y2 = [bbx[0], bbx[1], bbx[0] + bbx[2], bbx[1] + bbx[3]]
263
+ x1y1x2y2 = [int(round(x)) for x in x1y1x2y2]
264
+ gt_bbxs.append(x1y1x2y2)
265
+
266
+ return np.asarray(gt_bbxs), gt_clss
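+
+ # Illustrative note (not in the original code): COCO stores boxes as
+ # [x, y, width, height], so e.g. a box [10, 20, 30, 40] becomes the
+ # [x1, y1, x2, y2] box [10, 20, 40, 60] after the conversion above.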
267
+
268
+
269
+ def extract_gt_VOC(targets, remove_hards=False):
270
+ objects = targets["annotation"]["object"]
271
+ nb_obj = len(objects)
272
+
273
+ gt_bbxs = []
274
+ gt_clss = []
275
+ for o in range(nb_obj):
276
+ if remove_hards and (
277
+ objects[o]["truncated"] == "1" or objects[o]["difficult"] == "1"
278
+ ):
279
+ continue
280
+ gt_cls = objects[o]["name"]
281
+ gt_clss.append(gt_cls)
282
+ obj = objects[o]["bndbox"]
283
+ x1y1x2y2 = [
284
+ int(obj["xmin"]),
285
+ int(obj["ymin"]),
286
+ int(obj["xmax"]),
287
+ int(obj["ymax"]),
288
+ ]
289
+ # Original annotations are integers in the range [1, W or H]
290
+ # Assuming they mean 1-based pixel indices (inclusive),
291
+ # a box with annotation (xmin=1, xmax=W) covers the whole image.
292
+ # In coordinate space this is represented by (xmin=0, xmax=W)
293
+ x1y1x2y2[0] -= 1
294
+ x1y1x2y2[1] -= 1
295
+ gt_bbxs.append(x1y1x2y2)
296
+
297
+ return np.asarray(gt_bbxs), gt_clss
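+
+ # Illustrative note (assuming the 1-based, inclusive VOC convention described
+ # above): a full-image annotation (xmin=1, ymin=1, xmax=640, ymax=480) maps to
+ # [0, 0, 640, 480] in 0-based coordinate space after the -1 shift on xmin/ymin.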
298
+
299
+
300
+ def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
301
+ # https://github.com/ultralytics/yolov5/blob/develop/utils/general.py
302
+ # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
303
+ box2 = box2.T
304
+
305
+ # Get the coordinates of bounding boxes
306
+ if x1y1x2y2: # x1, y1, x2, y2 = box1
307
+ b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
308
+ b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
309
+ else: # transform from xywh to xyxy
310
+ b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
311
+ b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
312
+ b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
313
+ b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
314
+
315
+ # Intersection area
316
+ inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * (
317
+ torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)
318
+ ).clamp(0)
319
+
320
+ # Union Area
321
+ w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
322
+ w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
323
+ union = w1 * h1 + w2 * h2 - inter + eps
324
+
325
+ iou = inter / union
326
+ if GIoU or DIoU or CIoU:
327
+ cw = torch.max(b1_x2, b2_x2) - torch.min(
328
+ b1_x1, b2_x1
329
+ ) # convex (smallest enclosing box) width
330
+ ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
331
+ if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
332
+ c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
333
+ rho2 = (
334
+ (b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2
335
+ + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2
336
+ ) / 4 # center distance squared
337
+ if DIoU:
338
+ return iou - rho2 / c2 # DIoU
339
+ elif (
340
+ CIoU
341
+ ): # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
342
+ v = (4 / math.pi ** 2) * torch.pow(
343
+ torch.atan(w2 / h2) - torch.atan(w1 / h1), 2
344
+ )
345
+ with torch.no_grad():
346
+ alpha = v / (v - iou + (1 + eps))
347
+ return iou - (rho2 / c2 + v * alpha) # CIoU
348
+ else: # GIoU https://arxiv.org/pdf/1902.09630.pdf
349
+ c_area = cw * ch + eps # convex area
350
+ return iou - (c_area - union) / c_area # GIoU
351
+ else:
352
+ return iou # IoU
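+
+ # Minimal usage sketch (illustrative, hypothetical boxes): plain IoU of one
+ # predicted box against several ground-truth boxes in x1y1x2y2 format.
+ # pred = torch.tensor([0., 0., 10., 10.])
+ # gts = torch.tensor([[0., 0., 10., 10.], [5., 5., 15., 15.]])
+ # bbox_iou(pred, gts) # -> approximately tensor([1.00, 0.14])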
353
+
354
+ def select_coco_20k(sel_file, all_annotations_file):
355
+ print('Building COCO 20k dataset.')
356
+
357
+ # load all annotations
358
+ with open(all_annotations_file, "r") as f:
359
+ train2014 = json.load(f)
360
+
361
+ # load selected images
362
+ with open(sel_file, "r") as f:
363
+ sel_20k = f.readlines()
364
+ sel_20k = [s.replace("\n", "") for s in sel_20k]
365
+ im20k = [str(int(s.split("_")[-1].split(".")[0])) for s in sel_20k]
366
+
367
+ new_anno = []
368
+ new_images = []
369
+
370
+ for i in tqdm(im20k):
371
+ new_anno.extend(
372
+ [a for a in train2014["annotations"] if a["image_id"] == int(i)]
373
+ )
374
+ new_images.extend([a for a in train2014["images"] if a["id"] == int(i)])
375
+
376
+ train2014_20k = {}
377
+ train2014_20k["images"] = new_images
378
+ train2014_20k["annotations"] = new_anno
379
+ train2014_20k["categories"] = train2014["categories"]
380
+
381
+ with open("datasets/instances_train2014_sel20k.json", "w") as outfile:
382
+ json.dump(train2014_20k, outfile)
383
+
384
+ print('Done.')
datasets/utils.py ADDED
@@ -0,0 +1,44 @@
+ import numpy as np
2
+ import torch
3
+ from PIL import Image
4
+ from torchvision import transforms as T
5
+
6
+ NORMALIZE = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
7
+
8
+ class GaussianBlur:
9
+ """
10
+ Code borrowed from SelfMask: https://github.com/NoelShin/selfmask
11
+ """
12
+
13
+ # Implements Gaussian blur as described in the SimCLR paper
14
+ def __init__(self, kernel_size: float, min: float = 0.1, max: float = 2.0) -> None:
15
+ self.min = min
16
+ self.max = max
17
+ # kernel size is set to be 10% of the image height/width
18
+ self.kernel_size = kernel_size
19
+
20
+ def __call__(self, sample: Image.Image, random_gaussian_blur_p: float):
21
+ sample = np.array(sample)
22
+
23
+ # blur the image with a 50% chance
24
+ prob = np.random.random_sample()
25
+
26
+ if prob < 0.5:
27
+ import cv2
28
+
29
+ sigma = (self.max - self.min) * np.random.random_sample() + self.min
30
+ sample = cv2.GaussianBlur(
31
+ sample, (self.kernel_size, self.kernel_size), sigma
32
+ )
33
+ return sample
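+
+ # Usage sketch (illustrative): cv2.GaussianBlur requires an odd kernel size, so
+ # a kernel of roughly 10% of the image side should be rounded to an odd integer,
+ # e.g. blur = GaussianBlur(kernel_size=23); out = blur(pil_img, 0.5)  # ndarray
+ # where pil_img is a hypothetical PIL.Image input.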
34
+
35
+
36
+ def unnormalize(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
37
+ """
38
+ Code borrowed from STEGO: https://github.com/mhamilton723/STEGO
39
+ """
40
+ image2 = torch.clone(image)
41
+ for t, m, s in zip(image2, mean, std):
42
+ t.mul_(s).add_(m)
43
+
44
+ return image2
evaluation/__init__.py ADDED
File without changes
evaluation/metrics/__init__.py ADDED
File without changes
evaluation/metrics/average_meter.py ADDED
@@ -0,0 +1,21 @@
+ """
2
+ Code borrowed from SelfMask: https://github.com/NoelShin/selfmask
3
+ """
4
+
5
+ class AverageMeter(object):
6
+ """Computes and stores the average and current value"""
7
+
8
+ def __init__(self):
9
+ self.reset()
10
+
11
+ def reset(self):
12
+ self.val = 0
13
+ self.avg = 0
14
+ self.sum = 0
15
+ self.count = 0
16
+
17
+ def update(self, val, n: int):
18
+ self.val = val
19
+ self.sum += val * n
20
+ self.count += n
21
+ self.avg = self.sum / self.count
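+
+ # Usage sketch (illustrative): the meter keeps a count-weighted running average.
+ # m = AverageMeter()
+ # m.update(0.5, n=2) # sum = 1.0, count = 2
+ # m.update(1.0, n=1) # sum = 2.0, count = 3 -> m.avg == 2.0 / 3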
evaluation/metrics/f_measure.py ADDED
@@ -0,0 +1,111 @@
+ """
2
+ Code borrowed from SelfMask: https://github.com/NoelShin/selfmask
3
+ """
4
+
5
+ import torch
6
+
7
+ class FMeasure:
8
+ def __init__(
9
+ self,
10
+ default_thres: float = 0.5,
11
+ beta_square: float = 0.3,
12
+ n_bins: int = 255,
13
+ eps: float = 1e-7,
14
+ ):
15
+ """
16
+ :param default_thres: a hyperparameter for F-measure that is used to binarize a predicted mask. Default: 0.5
17
+ :param beta_square: a hyperparameter for F-measure. Default: 0.3
18
+ :param n_bins: the number of thresholds that will be tested for F-max. Default: 255
19
+ :param eps: a small value for numerical stability
20
+ """
21
+
22
+ self.beta_square = beta_square
23
+ self.default_thres = default_thres
24
+ self.eps = eps
25
+ self.n_bins = n_bins
26
+
27
+ def _compute_precision_recall(
28
+ self, binary_pred_mask: torch.Tensor, gt_mask: torch.Tensor
29
+ ) -> torch.Tensor:
30
+ """
31
+ :param binary_pred_mask: (B x H x W) or (H x W)
32
+ :param gt_mask: (B x H x W) or (H x W), should be the same with binary_pred_mask
33
+ """
34
+ tp = torch.logical_and(binary_pred_mask, gt_mask).sum(dim=(-1, -2))
35
+ tp_fp = binary_pred_mask.sum(dim=(-1, -2))
36
+ tp_fn = gt_mask.sum(dim=(-1, -2))
37
+
38
+ prec = tp / (tp_fp + self.eps)
39
+ recall = tp / (tp_fn + self.eps)
40
+ return prec, recall
41
+
42
+ def _compute_f_measure(
43
+ self,
44
+ pred_mask: torch.Tensor,
45
+ gt_mask: torch.Tensor,
46
+ thresholds: torch.Tensor = None,
47
+ ) -> torch.Tensor:
48
+ if thresholds is None:
49
+ binary_pred_mask = pred_mask > self.default_thres
50
+ else:
51
+ binary_pred_mask = pred_mask > thresholds
52
+
53
+ prec, recall = self._compute_precision_recall(binary_pred_mask, gt_mask)
54
+ f_measure = ((1 + (self.beta_square**2)) * prec * recall) / (
55
+ (self.beta_square**2) * prec + recall + self.eps
56
+ )
57
+ return f_measure.cpu()
58
+
59
+ def _compute_f_max(
60
+ self, pred_mask: torch.Tensor, gt_mask: torch.Tensor
61
+ ) -> torch.Tensor:
62
+ """Compute self.n_bins + 1 F-measures, each of which has a different threshold, then return the maximum
63
+ F-measure among them.
64
+
65
+ :param pred_mask: (H x W)
66
+ :param gt_mask: (H x W)
67
+ """
68
+
69
+ # pred_masks, gt_masks: H x W -> self.n_bins x H x W
70
+ pred_masks = pred_mask.unsqueeze(dim=0).repeat(self.n_bins, 1, 1)
71
+ gt_masks = gt_mask.unsqueeze(dim=0).repeat(self.n_bins, 1, 1)
72
+
73
+ # thresholds: self.n_bins x 1 x 1
74
+ thresholds = (
75
+ torch.arange(0, 1, 1 / self.n_bins)
76
+ .view(self.n_bins, 1, 1)
77
+ .to(pred_masks.device)
78
+ )
79
+
80
+ # f_measures: self.n_bins
81
+ f_measures = self._compute_f_measure(pred_masks, gt_masks, thresholds)
82
+ return torch.max(f_measures).cpu(), f_measures
83
+
84
+ def _compute_f_mean(
85
+ self,
86
+ pred_mask: torch.Tensor,
87
+ gt_mask: torch.Tensor,
88
+ ) -> torch.Tensor:
89
+ adaptive_thres = 2 * pred_mask.mean(dim=(-1, -2), keepdim=True)
90
+ binary_pred_mask = pred_mask > adaptive_thres
91
+
92
+ prec, recall = self._compute_precision_recall(binary_pred_mask, gt_mask)
93
+ f_mean = ((1 + (self.beta_square**2)) * prec * recall) / (
94
+ (self.beta_square**2) * prec + recall + self.eps
95
+ )
96
+ return f_mean.cpu()
97
+
98
+ def __call__(self, pred_mask: torch.Tensor, gt_mask: torch.Tensor) -> dict:
99
+ """
100
+ :param pred_mask: (H x W) a normalized prediction mask with values in [0, 1]
101
+ :param gt_mask: (H x W) a binary ground truth mask with values in {0, 1}
102
+ :return: a dictionary with keys being "f_measure" and "f_max" and values being the respective values.
103
+ """
104
+ outputs: dict = dict()
105
+ for k in ("f_measure", "f_mean"):
106
+ outputs.update({k: getattr(self, f"_compute_{k}")(pred_mask, gt_mask)})
107
+
108
+ f_max_, all_f = self._compute_f_max(pred_mask, gt_mask)
109
+ outputs["f_max"] = f_max_
110
+ outputs["all_f"] = all_f # List of all f values for all thresholds
111
+ return outputs
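+
+ # Usage sketch (illustrative, hypothetical tensors):
+ # fm = FMeasure()
+ # out = fm(pred_mask=torch.rand(64, 64), gt_mask=(torch.rand(64, 64) > 0.5).long())
+ # out contains "f_measure", "f_mean", "f_max" and "all_f" (one value per threshold).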
evaluation/metrics/iou.py ADDED
@@ -0,0 +1,37 @@
+ """
2
+ Code adapted from SelfMask: https://github.com/NoelShin/selfmask
3
+ """
4
+
5
+ from typing import Optional, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+
11
+ def compute_iou(
12
+ pred_mask: Union[np.ndarray, torch.Tensor],
13
+ gt_mask: Union[np.ndarray, torch.Tensor],
14
+ threshold: Optional[float] = 0.5,
15
+ eps: float = 1e-7,
16
+ ) -> Union[np.ndarray, torch.Tensor]:
17
+ """
18
+ :param pred_mask: (B x H x W) or (H x W)
19
+ :param gt_mask: (B x H x W) or (H x W), same shape with pred_mask
20
+ :param threshold: a binarization threshold
21
+ :param eps: a small value for computational stability
22
+ :return: (B) or (1)
23
+ """
24
+ assert pred_mask.shape == gt_mask.shape, f"{pred_mask.shape} != {gt_mask.shape}"
25
+ # assert 0. <= pred_mask.to(torch.float32).min() and pred_mask.max().to(torch.float32) <= 1., f"{pred_mask.min(), pred_mask.max()}"
26
+
27
+ if threshold is not None:
28
+ pred_mask = pred_mask > threshold
29
+ if isinstance(pred_mask, np.ndarray):
30
+ intersection = np.logical_and(pred_mask, gt_mask).sum(axis=(-1, -2))
31
+ union = np.logical_or(pred_mask, gt_mask).sum(axis=(-1, -2))
32
+ ious = intersection / (union + eps)
33
+ else:
34
+ intersection = torch.logical_and(pred_mask, gt_mask).sum(dim=(-1, -2))
35
+ union = torch.logical_or(pred_mask, gt_mask).sum(dim=(-1, -2))
36
+ ious = (intersection / (union + eps)).cpu()
37
+ return ious
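+
+ # Illustrative example: a 2x2 prediction covering the left column against a
+ # ground truth covering the top row gives intersection 1, union 3, IoU ~ 1/3.
+ # compute_iou(torch.tensor([[1., 0.], [1., 0.]]), torch.tensor([[1, 1], [0, 0]]))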
evaluation/metrics/mae.py ADDED
@@ -0,0 +1,14 @@
+ """
2
+ Code borrowed from SelfMask: https://github.com/NoelShin/selfmask
3
+ """
4
+
5
+ import torch
6
+
7
+ def compute_mae(pred_mask: torch.Tensor, gt_mask: torch.Tensor) -> torch.Tensor:
8
+ """
9
+ :param pred_mask: (H x W) or (B x H x W) a normalized prediction mask with values in [0, 1]
10
+ :param gt_mask: (H x W) or (B x H x W) a binary ground truth mask with values in {0, 1}
11
+ """
12
+ return torch.mean(
13
+ torch.abs(pred_mask - gt_mask.to(torch.float32)), dim=(-1, -2)
14
+ ).cpu()
evaluation/metrics/pixel_acc.py ADDED
@@ -0,0 +1,21 @@
+ """
2
+ Code borrowed from SelfMask: https://github.com/NoelShin/selfmask
3
+ """
4
+
5
+ from typing import Optional
6
+
7
+ import torch
8
+
9
+
10
+ def compute_pixel_accuracy(
11
+ pred_mask: torch.Tensor, gt_mask: torch.Tensor, threshold: Optional[float] = 0.5
12
+ ) -> torch.Tensor:
13
+ """
14
+ :param pred_mask: (H x W) or (B x H x W) a normalized prediction mask with values in [0, 1]
15
+ :param gt_mask: (H x W) or (B x H x W) a binary ground truth mask with values in {0, 1}
16
+ """
17
+ if threshold is not None:
18
+ binary_pred_mask = pred_mask > threshold
19
+ else:
20
+ binary_pred_mask = pred_mask
21
+ return (binary_pred_mask == gt_mask).to(torch.float32).mean(dim=(-1, -2)).cpu()
evaluation/metrics/s_measure.py ADDED
@@ -0,0 +1,126 @@
+ # code borrowed from https://github.com/Hanqer/Evaluate-SOD/blob/master/evaluator.py
2
+ import numpy as np
3
+ import torch
4
+
5
+
6
+ class SMeasure:
7
+ def __init__(self, alpha: float = 0.5):
8
+ self.alpha: float = alpha
9
+ self.cuda: bool = True
10
+
11
+ def _centroid(self, gt):
12
+ rows, cols = gt.size()[-2:]
13
+ gt = gt.view(rows, cols)
14
+ if gt.sum() == 0:
15
+ if self.cuda:
16
+ X = torch.eye(1).cuda() * round(cols / 2)
17
+ Y = torch.eye(1).cuda() * round(rows / 2)
18
+ else:
19
+ X = torch.eye(1) * round(cols / 2)
20
+ Y = torch.eye(1) * round(rows / 2)
21
+ else:
22
+ total = gt.sum()
23
+ if self.cuda:
24
+ i = torch.from_numpy(np.arange(0, cols)).cuda().float()
25
+ j = torch.from_numpy(np.arange(0, rows)).cuda().float()
26
+ else:
27
+ i = torch.from_numpy(np.arange(0, cols)).float()
28
+ j = torch.from_numpy(np.arange(0, rows)).float()
29
+ X = torch.round((gt.sum(dim=0) * i).sum() / total)
30
+ Y = torch.round((gt.sum(dim=1) * j).sum() / total)
31
+ return X.long(), Y.long()
32
+
33
+ def _ssim(self, pred, gt):
34
+ gt = gt.float()
35
+ h, w = pred.size()[-2:]
36
+ N = h * w
37
+ x = pred.mean()
38
+ y = gt.mean()
39
+ sigma_x2 = ((pred - x) * (pred - x)).sum() / (N - 1 + 1e-20)
40
+ sigma_y2 = ((gt - y) * (gt - y)).sum() / (N - 1 + 1e-20)
41
+ sigma_xy = ((pred - x) * (gt - y)).sum() / (N - 1 + 1e-20)
42
+
43
+ aplha = 4 * x * y * sigma_xy
44
+ beta = (x * x + y * y) * (sigma_x2 + sigma_y2)
45
+
46
+ if aplha != 0:
47
+ Q = aplha / (beta + 1e-20)
48
+ elif aplha == 0 and beta == 0:
49
+ Q = 1.0
50
+ else:
51
+ Q = 0
52
+ return Q
53
+
54
+ def _object(self, pred, gt):
55
+ temp = pred[gt == 1]
56
+ x = temp.mean()
57
+ sigma_x = temp.std()
58
+ score = 2.0 * x / (x * x + 1.0 + sigma_x + 1e-20)
59
+
60
+ return score
61
+
62
+ def _s_object(self, pred, gt):
63
+ fg = torch.where(gt == 0, torch.zeros_like(pred), pred)
64
+ bg = torch.where(gt == 1, torch.zeros_like(pred), 1 - pred)
65
+ o_fg = self._object(fg, gt)
66
+ o_bg = self._object(bg, 1 - gt)
67
+ u = gt.mean()
68
+ Q = u * o_fg + (1 - u) * o_bg
69
+ return Q
70
+
71
+ def _divide_gt(self, gt, X, Y):
72
+ h, w = gt.size()[-2:]
73
+ area = h * w
74
+ gt = gt.view(h, w)
75
+ LT = gt[:Y, :X]
76
+ RT = gt[:Y, X:w]
77
+ LB = gt[Y:h, :X]
78
+ RB = gt[Y:h, X:w]
79
+ X = X.float()
80
+ Y = Y.float()
81
+ w1 = X * Y / area
82
+ w2 = (w - X) * Y / area
83
+ w3 = X * (h - Y) / area
84
+ w4 = 1 - w1 - w2 - w3
85
+ return LT, RT, LB, RB, w1, w2, w3, w4
86
+
87
+ def _divide_prediction(self, pred, X, Y):
88
+ h, w = pred.size()[-2:]
89
+ pred = pred.view(h, w)
90
+ LT = pred[:Y, :X]
91
+ RT = pred[:Y, X:w]
92
+ LB = pred[Y:h, :X]
93
+ RB = pred[Y:h, X:w]
94
+ return LT, RT, LB, RB
95
+
96
+ def _s_region(self, pred, gt):
97
+ X, Y = self._centroid(gt)
98
+ gt1, gt2, gt3, gt4, w1, w2, w3, w4 = self._divide_gt(gt, X, Y)
99
+ p1, p2, p3, p4 = self._divide_prediction(pred, X, Y)
100
+ Q1 = self._ssim(p1, gt1)
101
+ Q2 = self._ssim(p2, gt2)
102
+ Q3 = self._ssim(p3, gt3)
103
+ Q4 = self._ssim(p4, gt4)
104
+ Q = w1 * Q1 + w2 * Q2 + w3 * Q3 + w4 * Q4
105
+ # print(Q)
106
+ return Q
107
+
108
+ def __call__(self, pred_mask: torch.Tensor, gt_mask: torch.Tensor):
109
+ assert pred_mask.shape == gt_mask.shape
110
+ y = gt_mask.mean()
111
+ if y == 0:
112
+ x = pred_mask.mean()
113
+ Q = 1.0 - x
114
+ elif y == 1:
115
+ x = pred_mask.mean()
116
+ Q = x
117
+ else:
118
+ gt_mask[gt_mask >= 0.5] = 1
119
+ gt_mask[gt_mask < 0.5] = 0
120
+ # print(self._S_object(pred, gt), self._S_region(pred, gt))
121
+ Q = self.alpha * self._s_object(pred_mask, gt_mask) + (
122
+ 1 - self.alpha
123
+ ) * self._s_region(pred_mask, gt_mask)
124
+ if Q.item() < 0:
125
+ Q = torch.FloatTensor([0.0])
126
+ return Q.item()
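+
+ # Usage sketch (illustrative): S-measure mixes object-aware and region-aware
+ # structural similarity with weight alpha (0.5 by default). Note that this
+ # implementation assumes CUDA tensors, since self.cuda is hard-coded to True.
+ # score = SMeasure()(pred_mask.cuda(), gt_mask.cuda().float()) # a float score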
evaluation/saliency.py ADDED
@@ -0,0 +1,290 @@
+ # Copyright 2022 - Valeo Comfort and Driving Assistance - valeo.ai
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import numpy as np
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+
20
+ from tqdm import tqdm
21
+ from scipy import ndimage
22
+
23
+ from evaluation.metrics.average_meter import AverageMeter
24
+ from evaluation.metrics.f_measure import FMeasure
25
+ from evaluation.metrics.iou import compute_iou
26
+ from evaluation.metrics.mae import compute_mae
27
+ from evaluation.metrics.pixel_acc import compute_pixel_accuracy
28
+ from evaluation.metrics.s_measure import SMeasure
29
+
30
+ from misc import batch_apply_bilateral_solver
31
+
32
+
33
+ @torch.no_grad()
34
+ def write_metric_tf(
35
+ writer,
36
+ metrics,
37
+ n_iter = -1,
38
+ name = ""
39
+ ):
40
+ writer.add_scalar(
41
+ f"Validation/{name}iou_pred",
42
+ metrics["ious"].avg,
43
+ n_iter,
44
+ )
45
+ writer.add_scalar(
46
+ f"Validation/{name}acc_pred",
47
+ metrics["pixel_accs"].avg,
48
+ n_iter,
49
+ )
50
+ writer.add_scalar(
51
+ f"Validation/{name}f_max",
52
+ metrics["f_maxs"].avg,
53
+ n_iter,
54
+ )
55
+
56
+ @torch.no_grad()
57
+ def eval_batch(
58
+ batch_gt_masks,
59
+ batch_pred_masks,
60
+ metrics_res={},
61
+ reset=False
62
+ ):
63
+ """
64
+ Evaluation code adapted from SelfMask: https://github.com/NoelShin/selfmask
65
+ """
66
+
67
+ f_values = {}
68
+ # Keep track of f_values for each threshold
69
+ for i in range(255): # should equal n_bins in metrics/f_measure.py
70
+ f_values[i] = AverageMeter()
71
+
72
+ if metrics_res == {}:
73
+ metrics_res["f_scores"] = AverageMeter()
74
+ metrics_res["f_maxs"] = AverageMeter()
75
+ metrics_res["f_maxs_fixed"] = AverageMeter()
76
+ metrics_res["f_means"] = AverageMeter()
77
+ metrics_res["maes"] = AverageMeter()
78
+ metrics_res["ious"] = AverageMeter()
79
+ metrics_res["pixel_accs"] = AverageMeter()
80
+ metrics_res["s_measures"] = AverageMeter()
81
+
82
+ if reset:
83
+ metrics_res["f_scores"].reset()
84
+ metrics_res["f_maxs"].reset()
85
+ metrics_res["f_maxs_fixed"].reset()
86
+ metrics_res["f_means"].reset()
87
+ metrics_res["maes"].reset()
88
+ metrics_res["ious"].reset()
89
+ metrics_res["pixel_accs"].reset()
90
+ metrics_res["s_measures"].reset()
91
+
92
+ # iterate over batch dimension
93
+ for _, (pred_mask, gt_mask) in enumerate(
94
+ zip(batch_pred_masks, batch_gt_masks)
95
+ ):
96
+ assert pred_mask.shape == gt_mask.shape, f"{pred_mask.shape} != {gt_mask.shape}"
97
+ assert len(pred_mask.shape) == len(gt_mask.shape) == 2
98
+ # Compute
99
+ # Binarize at 0.5 for IoU and pixel accuracy
100
+ binary_pred = (pred_mask > 0.5).float().squeeze()
101
+ iou = compute_iou(binary_pred, gt_mask)
102
+ f_measures = FMeasure()(pred_mask, gt_mask) # soft mask for F measure
103
+ mae = compute_mae(binary_pred, gt_mask)
104
+ pixel_acc = compute_pixel_accuracy(binary_pred, gt_mask)
105
+
106
+ # Update
107
+ metrics_res["ious"].update(val=iou.numpy(), n=1)
108
+ metrics_res["f_scores"].update(val=f_measures["f_measure"].numpy(), n=1)
109
+ metrics_res["f_maxs"].update(val=f_measures["f_max"].numpy(), n=1)
110
+ metrics_res["f_means"].update(val=f_measures["f_mean"].numpy(), n=1)
111
+ metrics_res["s_measures"].update(
112
+ val=SMeasure()(pred_mask=pred_mask, gt_mask=gt_mask.to(torch.float32)), n=1
113
+ )
114
+ metrics_res["maes"].update(val=mae.numpy(), n=1)
115
+ metrics_res["pixel_accs"].update(val=pixel_acc.numpy(), n=1)
116
+
117
+ # Keep track of f_values for each threshold
118
+ all_f = f_measures["all_f"].numpy()
119
+ for k, v in f_values.items():
120
+ v.update(val=all_f[k], n=1)
121
+ # Then compute the max for the f_max_fixed
122
+ metrics_res["f_maxs_fixed"].update(
123
+ val=np.max([v.avg for v in f_values.values()]), n=1
124
+ )
125
+
126
+ results = {}
127
+ # F-measure, F-max, F-mean, MAE, S-measure, IoU, pixel acc.
128
+ results["f_measure"] = metrics_res["f_scores"].avg
129
+ results["f_max"] = metrics_res["f_maxs"].avg
130
+ results["f_maxs_fixed"] = metrics_res["f_maxs_fixed"].avg
131
+ results["f_mean"] = metrics_res["f_means"].avg
132
+ results["s_measure"] = metrics_res["s_measures"].avg
133
+ results["mae"] = metrics_res["maes"].avg
134
+ results["iou"] = float(iou.numpy())
135
+ results["pixel_acc"] = metrics_res["pixel_accs"].avg
136
+
137
+ return results, metrics_res
138
+
139
+ def evaluate_saliency(
140
+ dataset,
141
+ model,
142
+ writer=None,
143
+ batch_size=1,
144
+ n_iter=-1,
145
+ apply_bilateral=False,
146
+ im_fullsize=True,
147
+ method="pred", # can also be "bkg",
148
+ apply_weights: bool = True,
149
+ evaluation_mode: str = 'single', # choices are ["single", "multi"]
150
+ ):
151
+
152
+ if im_fullsize:
153
+ # Change transformation
154
+ dataset.fullimg_mode()
155
+ batch_size = 1
156
+
157
+ valloader = torch.utils.data.DataLoader(
158
+ dataset,
159
+ batch_size=batch_size,
160
+ shuffle=False,
161
+ num_workers=2
162
+ )
163
+
164
+ sigmoid = nn.Sigmoid()
165
+
166
+ metrics_res = {}
167
+ metrics_res_bs = {}
168
+ valbar = tqdm(enumerate(valloader, 0), leave=None)
169
+ for i, data in valbar:
170
+ inputs, _, gt_labels, _ = data
171
+ inputs = inputs.to("cuda")
172
+ gt_labels = gt_labels.to("cuda").float()
173
+
174
+ # Forward step
175
+ with torch.no_grad():
176
+ preds, _, shape_f, att = model.forward_step(inputs, for_eval=True)
177
+
178
+ if method == "pred":
179
+ h, w = gt_labels.shape[-2:]
180
+ preds_up = F.interpolate(
181
+ preds, scale_factor=model.vit_patch_size, mode="bilinear", align_corners=False
182
+ )[..., :h, :w]
183
+ soft_preds = sigmoid(preds_up.detach()).squeeze(0)
184
+ preds_up = (
185
+ (sigmoid(preds_up.detach()) > 0.5).squeeze(0).float()
186
+ )
187
+
188
+ elif method == "bkg":
189
+ bkg_mask_pred = model.compute_background_batch(
190
+ att, shape_f,
191
+ apply_weights=apply_weights,
192
+ )
193
+ # Transform bkg detection to foreground detection
194
+ obj_mask = (
195
+ ~bkg_mask_pred.bool()
196
+ ).float() # Obj labels is inverse of bkg
197
+
198
+ # Fit predictions to image size
199
+ preds_up = F.interpolate(
200
+ obj_mask.unsqueeze(1),
201
+ gt_labels.shape[-2:],
202
+ mode="bilinear",
203
+ align_corners=False,
204
+ )
205
+ preds_up = (preds_up > 0.5).float()
206
+ soft_preds = preds_up # not soft actually
207
+
208
+ reset = True if i == 0 else False
209
+ if evaluation_mode == 'single':
210
+ labeled, nr_objects = ndimage.label(preds_up.squeeze().cpu().numpy())
211
+ if nr_objects == 0:
212
+ preds_up_one_cc = preds_up.squeeze()
213
+ print("nr_objects == 0")
214
+ else:
215
+ nb_pixel = [np.sum(labeled == i) for i in range(nr_objects + 1)]
216
+ pixel_order = np.argsort(nb_pixel)
217
+
218
+ cc = [torch.Tensor(labeled == i) for i in pixel_order]
219
+ cc = torch.stack(cc).cuda()
220
+
221
+ # Find CC set as background, here not necessarily the biggest
222
+ cc_background = (
223
+ (
224
+ (
225
+ (~(preds_up[None, :, :, :].bool())).float()
226
+ + cc[:, None, :, :].cuda()
227
+ )
228
+ > 1
229
+ ).sum(-1).sum(-1).argmax()
230
+ )
231
+ pixel_order = np.delete(
232
+ pixel_order, int(cc_background.cpu().numpy())
233
+ )
234
+
235
+ preds_up_one_cc = torch.Tensor(labeled == pixel_order[-1]).cuda()
236
+
237
+ _, metrics_res = eval_batch(
238
+ gt_labels,
239
+ preds_up_one_cc.unsqueeze(0),
240
+ metrics_res=metrics_res,
241
+ reset=reset,
242
+ )
243
+
244
+ if writer is not None:
245
+ write_metric_tf(writer, metrics_res, n_iter=n_iter, name=f"_{evaluation_mode}_")
246
+
247
+ elif evaluation_mode == 'multi':
248
+ # Eval without bilateral solver
249
+ _, metrics_res = eval_batch(
250
+ gt_labels,
251
+ soft_preds.unsqueeze(0) if len(soft_preds.shape) == 2 else soft_preds,
252
+ metrics_res=metrics_res,
253
+ reset=reset,
254
+ ) # soft preds needed for F beta measure
255
+
256
+ # Apply bilateral solver
257
+ preds_bs = None
258
+ if apply_bilateral:
259
+ get_all_cc = True if evaluation_mode == 'multi' else False
260
+ preds_bs, _ = batch_apply_bilateral_solver(data,
261
+ preds_up.detach(),
262
+ get_all_cc = get_all_cc
263
+ )
264
+
265
+ _, metrics_res_bs = eval_batch(
266
+ gt_labels,
267
+ preds_bs[None,:,:].float(),
268
+ metrics_res=metrics_res_bs,
269
+ reset=reset
270
+ )
271
+
272
+ if writer is not None:
273
+ write_metric_tf(writer, metrics_res_bs, n_iter=n_iter, name=f"_{evaluation_mode}-BS_")
274
+
275
+ bar_str = f"{dataset.name} | {evaluation_mode} mode | " \
276
+ f"F-max {metrics_res['f_maxs'].avg:.3f} " \
277
+ f"IoU {metrics_res['ious'].avg:.3f}, " \
278
+ f"PA {metrics_res['pixel_accs'].avg:.3f}"
279
+
280
+ if apply_bilateral:
281
+ bar_str += f" | with bilateral solver: " \
282
+ f"F-max {metrics_res_bs['f_maxs'].avg:.3f}, " \
283
+ f"IoU {metrics_res_bs['ious'].avg:.3f}, " \
284
+ f"PA. {metrics_res_bs['pixel_accs'].avg:.3f}"
285
+
286
+ valbar.set_description(bar_str)
287
+
288
+ # Go back to original transformation
289
+ if im_fullsize:
290
+ dataset.training_mode()
evaluation/uod.py ADDED
@@ -0,0 +1,118 @@
+ # Copyright 2021 - Valeo Comfort and Driving Assistance - Oriane Siméoni @ valeo.ai
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Code adapted from previous method LOST: https://github.com/valeoai/LOST
17
+ """
18
+
19
+ import os
20
+ import time
21
+ import torch
22
+ import torch.nn as nn
23
+ import numpy as np
24
+
25
+ from tqdm import tqdm
26
+ from misc import bbox_iou, get_bbox_from_segmentation_labels
27
+
28
+
29
+ def evaluation_unsupervised_object_discovery(
30
+ dataset,
31
+ model,
32
+ evaluation_mode: str = 'single', # choices are ["single", "multi"]
33
+ output_dir:str = "outputs",
34
+ no_hards:bool = False,
35
+ ):
36
+
37
+ assert evaluation_mode == "single"
38
+
39
+ sigmoid = nn.Sigmoid()
40
+
41
+ # ----------------------------------------------------
42
+ # Loop over images
43
+ preds_dict = {}
44
+ cnt = 0
45
+ corloc = np.zeros(len(dataset.dataloader))
46
+
47
+ start_time = time.time()
48
+ pbar = tqdm(dataset.dataloader)
49
+ for im_id, inp in enumerate(pbar):
50
+
51
+ # ------------ IMAGE PROCESSING -------------------------------------------
52
+ img = inp[0]
53
+
54
+ init_image_size = img.shape
55
+
56
+ # Get the name of the image
57
+ im_name = dataset.get_image_name(inp[1])
58
+ # Pass in case of no gt boxes in the image
59
+ if im_name is None:
60
+ continue
61
+
62
+ # Padding the image with zeros to fit multiple of patch-size
63
+ size_im = (
64
+ img.shape[0],
65
+ int(np.ceil(img.shape[1] / model.vit_patch_size) * model.vit_patch_size),
66
+ int(np.ceil(img.shape[2] / model.vit_patch_size) * model.vit_patch_size),
67
+ )
68
+ paded = torch.zeros(size_im)
69
+ paded[:, : img.shape[1], : img.shape[2]] = img
70
+ img = paded
71
+
72
+ # # Move to gpu
73
+ img = img.cuda(non_blocking=True)
74
+
75
+ # Size for transformers
76
+ w_featmap = img.shape[-2] // model.vit_patch_size
77
+ h_featmap = img.shape[-1] // model.vit_patch_size
78
+
79
+ # ------------ GROUND-TRUTH -------------------------------------------
80
+ gt_bbxs, gt_cls = dataset.extract_gt(inp[1], im_name)
81
+
82
+ if gt_bbxs is not None:
83
+ # Discard images with no gt annotations
84
+ # Happens only in the case of VOC07 and VOC12
85
+ if gt_bbxs.shape[0] == 0 and no_hards:
86
+ continue
87
+
88
+ outputs = model.forward_step(img[None, :, :, :])
89
+ preds = (sigmoid(outputs[0].detach()) > 0.5).float().squeeze().cpu().numpy()
90
+
91
+ # get bbox
92
+ pred = get_bbox_from_segmentation_labels(
93
+ segmenter_predictions=preds,
94
+ scales=[model.vit_patch_size, model.vit_patch_size],
95
+ initial_image_size=init_image_size[1:],
96
+ )
97
+
98
+ # ------------ Visualizations -------------------------------------------
99
+ # Save the prediction
100
+ preds_dict[im_name] = pred
101
+
102
+
103
+ # Compare prediction to GT boxes
104
+ ious = bbox_iou(torch.from_numpy(pred), torch.from_numpy(gt_bbxs))
105
+
106
+ if torch.any(ious >= 0.5):
107
+ corloc[im_id] = 1
108
+
109
+ cnt += 1
110
+ if cnt % 50 == 0:
111
+ pbar.set_description(f"Found {int(np.sum(corloc))}/{cnt}")
112
+
113
+ # Evaluate
114
+ print(f"corloc: {100*np.sum(corloc)/cnt:.2f} ({int(np.sum(corloc))}/{cnt})")
115
+ result_file = os.path.join(output_dir, 'uod_results.txt')
116
+ with open(result_file, 'w') as f:
117
+ f.write('corloc,%.1f,,\n'%(100*np.sum(corloc)/cnt))
118
+ print('File saved at %s'%result_file)
main_found_evaluate.py ADDED
@@ -0,0 +1,122 @@
+ # Copyright 2022 - Valeo Comfort and Driving Assistance - Oriane Siméoni @ valeo.ai
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ from model import FoundModel
17
+ from misc import load_config
18
+ from datasets.datasets import build_dataset
19
+ from evaluation.saliency import evaluate_saliency
20
+ from evaluation.uod import evaluation_unsupervised_object_discovery
21
+
22
+ if __name__ == "__main__":
23
+ parser = argparse.ArgumentParser(
24
+ description = 'Evaluation of FOUND',
25
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
26
+ )
27
+ parser.add_argument(
28
+ "--eval-type",
29
+ type=str,
30
+ choices=["saliency", "uod"],
31
+ help="Evaluation type."
32
+ )
33
+ parser.add_argument(
34
+ "--dataset-eval",
35
+ type=str,
36
+ choices=["ECSSD", "DUT-OMRON", "DUTS-TEST", "VOC07", "VOC12", "COCO20k"],
37
+ help="Name of evaluation dataset."
38
+ )
39
+ parser.add_argument(
40
+ "--dataset-set-eval",
41
+ type=str,
42
+ default=None,
43
+ help="Set of the dataset."
44
+ )
45
+ parser.add_argument(
46
+ "--apply-bilateral",
47
+ action="store_true",
48
+ help="use bilateral solver."
49
+ )
50
+ parser.add_argument(
51
+ "--evaluation-mode",
52
+ type=str,
53
+ default="multi",
54
+ choices=["single", "multi"],
55
+ help="Type of evaluation."
56
+ )
57
+ parser.add_argument(
58
+ "--model-weights",
59
+ type=str,
60
+ default="data/weights/decoder_weights.pt",
61
+ )
62
+ parser.add_argument(
63
+ "--dataset-dir",
64
+ type=str,
65
+ default="/datasets_local",
66
+ )
67
+ parser.add_argument(
68
+ "--config",
69
+ type=str,
70
+ default="configs/found_DUTS-TR.yaml",
71
+ )
72
+ args = parser.parse_args()
73
+ print(args.__dict__)
74
+
75
+ # Configuration
76
+ config = load_config(args.config)
77
+
78
+ # ------------------------------------
79
+ # Load the model
80
+ model = FoundModel(vit_model=config.model["pre_training"],
81
+ vit_arch=config.model["arch"],
82
+ vit_patch_size=config.model["patch_size"],
83
+ enc_type_feats=config.found["feats"],
84
+ bkg_type_feats=config.found["feats"],
85
+ bkg_th=config.found["bkg_th"])
86
+ # Load weights
87
+ model.decoder_load_weights(args.model_weights)
88
+ model.eval()
89
+ print(f"Model {args.model_weights} loaded correctly.")
90
+
91
+ # ------------------------------------
92
+ # Build the validation set
93
+ val_dataset = build_dataset(
94
+ root_dir=args.dataset_dir,
95
+ dataset_name=args.dataset_eval,
96
+ dataset_set=args.dataset_set_eval,
97
+ for_eval=True,
98
+ evaluation_type=args.eval_type,
99
+ )
100
+ print(f"\nBuilding dataset {val_dataset.name} (#{len(val_dataset)} images)")
101
+
102
+ # ------------------------------------
103
+ # Evaluation
104
+ print(f"\nStarted evaluation on {val_dataset.name}")
105
+ if args.eval_type == "saliency":
106
+ evaluate_saliency(
107
+ val_dataset,
108
+ model=model,
109
+ evaluation_mode=args.evaluation_mode,
110
+ apply_bilateral=args.apply_bilateral,
111
+ )
112
+ elif args.eval_type == "uod":
113
+ if args.apply_bilateral:
114
+ raise ValueError("Not implemented.")
115
+
116
+ evaluation_unsupervised_object_discovery(
117
+ val_dataset,
118
+ model=model,
119
+ evaluation_mode=args.evaluation_mode,
120
+ )
121
+ else:
122
+ raise ValueError("Other evaluation method to come.")
main_visualize.py ADDED
@@ -0,0 +1,99 @@
+ # Copyright 2022 - Valeo Comfort and Driving Assistance - Oriane Siméoni @ valeo.ai
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import torch
17
+ import argparse
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ import matplotlib.pyplot as plt
21
+
22
+ from PIL import Image
23
+ from model import FoundModel
24
+ from misc import load_config
25
+ from torchvision import transforms as T
26
+
27
+ NORMALIZE = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
28
+
29
+ if __name__ == "__main__":
30
+ parser = argparse.ArgumentParser(
31
+ description = 'Visualization of FOUND',
32
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
33
+ )
34
+
35
+ parser.add_argument(
36
+ "--img-path", type=str, default="data/examples/VOC07_000007.jpg", help="Image path."
37
+ )
38
+ parser.add_argument(
39
+ "--model-weights", type=str, default="data/weights/decoder_weights.pt",
40
+ )
41
+ parser.add_argument(
42
+ "--config", type=str, default="configs/found_DUTS-TR.yaml",
43
+ )
44
+ parser.add_argument(
45
+ "--output-dir", type=str, default="outputs",
46
+ )
47
+ args = parser.parse_args()
48
+
49
+ # Saving dir
50
+ if not os.path.exists(args.output_dir):
51
+ os.makedirs(args.output_dir)
52
+
53
+ # Configuration
54
+ config = load_config(args.config)
55
+
56
+ # ------------------------------------
57
+ # Load the model
58
+ model = FoundModel(vit_model=config.model["pre_training"],
59
+ vit_arch=config.model["arch"],
60
+ vit_patch_size=config.model["patch_size"],
61
+ enc_type_feats=config.found["feats"],
62
+ bkg_type_feats=config.found["feats"],
63
+ bkg_th=config.found["bkg_th"])
64
+ # Load weights
65
+ model.decoder_load_weights(args.model_weights)
66
+ model.eval()
67
+ print(f"Model {args.model_weights} loaded correctly.")
68
+
69
+ # Load the image
70
+ with open(args.img_path, "rb") as f:
71
+ img = Image.open(f)
72
+ img = img.convert("RGB")
73
+
74
+ t = T.Compose([T.ToTensor(), NORMALIZE])
75
+ img_t = t(img)[None,:,:,:]
76
+ inputs = img_t.to("cuda")
77
+
78
+ # Forward step
79
+ with torch.no_grad():
80
+ preds, _, shape_f, att = model.forward_step(inputs, for_eval=True)
81
+
82
+ # Apply FOUND
83
+ sigmoid = nn.Sigmoid()
84
+ h, w = img_t.shape[-2:]
85
+ preds_up = F.interpolate(
86
+ preds, scale_factor=model.vit_patch_size, mode="bilinear", align_corners=False
87
+ )[..., :h, :w]
88
+ preds_up = (
89
+ (sigmoid(preds_up.detach()) > 0.5).squeeze(0).float()
90
+ )
91
+
92
+ plt.figure()
93
+ plt.imshow(img)
94
+ plt.imshow(preds_up.cpu().squeeze().numpy(), 'gray', interpolation='none', alpha=0.5)
95
+ plt.axis('off')
96
+ img_name = args.img_path
97
+ img_name = img_name.split('/')[-1].split('.')[0]
98
+ plt.savefig(os.path.join(args.output_dir, f'{img_name}-found.png'), bbox_inches='tight', pad_inches=0)
99
+ plt.close()
misc.py ADDED
@@ -0,0 +1,254 @@
+ import re
2
+ import os
3
+ import cv2
4
+ import yaml
5
+ import math
6
+ import random
7
+ import scipy.ndimage
8
+ import numpy as np
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+
13
+ from typing import List
14
+ from torchvision import transforms as T
15
+
16
+ from bilateral_solver import bilateral_solver_output
17
+
18
+
19
+ loader = yaml.SafeLoader
20
+ loader.add_implicit_resolver(
21
+ u'tag:yaml.org,2002:float',
22
+ re.compile(u'''^(?:
23
+ [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
24
+ |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
25
+ |\\.[0-9_]+(?:[eE][-+][0-9]+)?
26
+ |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
27
+ |[-+]?\\.(?:inf|Inf|INF)
28
+ |\\.(?:nan|NaN|NAN))$''', re.X),
29
+ list(u'-+0123456789.'))
30
+
31
+ class Struct:
32
+ def __init__(self, **entries):
33
+ self.__dict__.update(entries)
34
+
35
+ def load_config(config_file):
36
+ with open(config_file, errors='ignore') as f:
37
+ # conf = yaml.safe_load(f) # load config
38
+ conf = yaml.load(f, Loader=loader)
39
+ print('hyperparameters: ' + ', '.join(f'{k}={v}' for k, v in conf.items()))
40
+
41
+ #TODO yaml_save(save_dir / 'config.yaml', conf)
42
+ return Struct(**conf)
43
+
44
+ def set_seed(seed: int) -> None:
45
+ """
46
+ Set all seeds to make results reproducible
47
+ """
48
+ # env
49
+ os.environ["PYTHONHASHSEED"] = str(seed)
50
+
51
+ # python
52
+ random.seed(seed)
53
+
54
+ # numpy
55
+ np.random.seed(seed)
56
+
57
+ # torch
58
+ torch.manual_seed(seed)
59
+ torch.cuda.manual_seed_all(seed)
60
+ torch.backends.cudnn.deterministic = True
61
+
62
+ def IoU(mask1, mask2):
63
+ """
64
+ Code adapted from TokenCut: https://github.com/YangtaoWANG95/TokenCut
65
+ """
66
+ mask1, mask2 = (mask1 > 0.5).to(torch.bool), (mask2 > 0.5).to(torch.bool)
67
+ intersection = torch.sum(mask1 * (mask1 == mask2), dim=[-1, -2]).squeeze()
68
+ union = torch.sum(mask1 + mask2, dim=[-1, -2]).squeeze()
69
+ return (intersection.to(torch.float) / union).mean().item()
70
+
71
+ def batch_apply_bilateral_solver(data,
72
+ masks,
73
+ get_all_cc=True,
74
+ shape=None):
75
+
76
+ cnt_bs = 0
77
+ masks_bs = []
78
+ inputs, init_imgs, gt_labels, img_path = data
79
+
80
+ for id in range(inputs.shape[0]):
81
+ _, bs_mask, use_bs = apply_bilateral_solver(
82
+ mask=masks[id].squeeze().cpu().numpy(),
83
+ img=init_imgs[id],
84
+ img_path=img_path[id],
85
+ im_fullsize=False,
86
+ # Careful shape should be opposed
87
+ shape=(gt_labels.shape[-1], gt_labels.shape[-2]),
88
+ get_all_cc=get_all_cc,
89
+ )
90
+ cnt_bs += use_bs
91
+
92
+ # use the bilateral solver output if IoU > 0.5
93
+ if use_bs:
94
+ if shape is None:
95
+ shape = masks.shape[-2:]
96
+ # Interpolate to downsample the mask back
97
+ bs_ds = F.interpolate(
98
+ torch.Tensor(bs_mask).unsqueeze(0).unsqueeze(0),
99
+ shape, # TODO check here
100
+ mode="bilinear",
101
+ align_corners=False,
102
+ )
103
+ masks_bs.append(bs_ds.bool().cuda().squeeze()[None, :, :])
104
+ else:
105
+ # Use initial mask
106
+ masks_bs.append(masks[id].cuda().squeeze()[None, :, :])
107
+
108
+ return torch.cat(masks_bs).squeeze(), cnt_bs
109
+
110
+
111
+ def apply_bilateral_solver(
112
+ mask,
113
+ img,
114
+ img_path,
115
+ shape,
116
+ im_fullsize=False,
117
+ get_all_cc=False,
118
+ bs_iou_threshold: float = 0.5,
119
+ reshape: bool = True,
120
+ ):
121
+ # Get initial image in the case of using full image
122
+ img_init = None
123
+ if not im_fullsize:
124
+ # Use the image given by dataloader
125
+ shape = (img.shape[-1], img.shape[-2])
126
+ t = T.ToPILImage()
127
+ img_init = t(img)
128
+
129
+ if reshape:
130
+ # Resize predictions to image size
131
+ resized_mask = cv2.resize(mask, shape)
132
+ sel_obj_mask = resized_mask
133
+ else:
134
+ resized_mask = mask
135
+ sel_obj_mask = mask
136
+
137
+ # Apply bilinear solver
138
+ _, binary_solver = bilateral_solver_output(
139
+ img_path,
140
+ resized_mask,
141
+ img=img_init,
142
+ sigma_spatial=16,
143
+ sigma_luma=16,
144
+ sigma_chroma=8,
145
+ get_all_cc=get_all_cc,
146
+ )
147
+
148
+ mask1 = torch.from_numpy(resized_mask).cuda()
149
+ mask2 = torch.from_numpy(binary_solver).cuda().float()
150
+
151
+ use_bs = 0
152
+ # If enough overlap, use BS output
153
+ if IoU(mask1, mask2) > bs_iou_threshold:
154
+ sel_obj_mask = binary_solver.astype(float)
155
+ use_bs = 1
156
+
157
+ return resized_mask, sel_obj_mask, use_bs
158
+
159
+ def get_bbox_from_segmentation_labels(
160
+ segmenter_predictions: torch.Tensor,
161
+ initial_image_size: torch.Size,
162
+ scales: List[int],
163
+ ) -> np.array:
164
+ """
165
+ Find the largest connected component in foreground, extract its bounding box
166
+ """
167
+ objects, num_objects = scipy.ndimage.label(segmenter_predictions)
168
+
169
+ # find biggest connected component
170
+ all_foreground_labels = objects.flatten()[objects.flatten() != 0]
171
+ most_frequent_label = np.bincount(all_foreground_labels).argmax()
172
+ mask = np.where(objects == most_frequent_label)
173
+ # Add +1 because excluded max
174
+ ymin, ymax = min(mask[0]), max(mask[0]) + 1
175
+ xmin, xmax = min(mask[1]), max(mask[1]) + 1
176
+
177
+ if initial_image_size == segmenter_predictions.shape:
178
+ # Masks are already upsampled
179
+ pred = [xmin, ymin, xmax, ymax]
180
+ else:
181
+ # Rescale to image size
182
+ r_xmin, r_xmax = scales[1] * xmin, scales[1] * xmax
183
+ r_ymin, r_ymax = scales[0] * ymin, scales[0] * ymax
184
+ pred = [r_xmin, r_ymin, r_xmax, r_ymax]
185
+
186
+ # Check not out of image size (used when padding)
187
+ if initial_image_size:
188
+ pred[2] = min(pred[2], initial_image_size[1])
189
+ pred[3] = min(pred[3], initial_image_size[0])
190
+
191
+ return np.asarray(pred)
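+
+ # Worked example (illustrative): if the largest foreground component of the
+ # patch-level mask spans rows 2-4 and columns 3-6 and scales = [8, 8] (ViT patch
+ # size 8), the box is [3*8, 2*8, 7*8, 5*8] = [24, 16, 56, 40] in pixels (the +1
+ # above makes the max index exclusive), then clamped to the initial image size.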
192
+
193
+
194
+ def bbox_iou(
195
+ box1: np.array,
196
+ box2: np.array,
197
+ x1y1x2y2: bool = True,
198
+ GIoU: bool = False,
199
+ DIoU: bool = False,
200
+ CIoU: bool = False,
201
+ eps: float = 1e-7,
202
+ ):
203
+ # https://github.com/ultralytics/yolov5/blob/develop/utils/general.py
204
+ # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
205
+ box2 = box2.T
206
+
207
+ # Get the coordinates of bounding boxes
208
+ if x1y1x2y2: # x1, y1, x2, y2 = box1
209
+ b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
210
+ b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
211
+ else: # transform from xywh to xyxy
212
+ b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
213
+ b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
214
+ b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
215
+ b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
216
+
217
+ # Intersection area
218
+ inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * (
219
+ torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)
220
+ ).clamp(0)
221
+
222
+ # Union Area
223
+ w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
224
+ w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
225
+ union = w1 * h1 + w2 * h2 - inter + eps
226
+
227
+ iou = inter / union
228
+ if GIoU or DIoU or CIoU:
229
+ cw = torch.max(b1_x2, b2_x2) - torch.min(
230
+ b1_x1, b2_x1
231
+ ) # convex (smallest enclosing box) width
232
+ ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
233
+ if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
234
+ c2 = cw**2 + ch**2 + eps # convex diagonal squared
235
+ rho2 = (
236
+ (b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2
237
+ + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2
238
+ ) / 4 # center distance squared
239
+ if DIoU:
240
+ return iou - rho2 / c2 # DIoU
241
+ elif (
242
+ CIoU
243
+ ): # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
244
+ v = (4 / math.pi**2) * torch.pow(
245
+ torch.atan(w2 / h2) - torch.atan(w1 / h1), 2
246
+ )
247
+ with torch.no_grad():
248
+ alpha = v / (v - iou + (1 + eps))
249
+ return iou - (rho2 / c2 + v * alpha) # CIoU
250
+ else: # GIoU https://arxiv.org/pdf/1902.09630.pdf
251
+ c_area = cw * ch + eps # convex area
252
+ return iou - (c_area - union) / c_area # GIoU
253
+ else:
254
+ return iou # IoU
model.py ADDED
@@ -0,0 +1,243 @@
+ # Copyright 2022 - Valeo Comfort and Driving Assistance - Oriane Siméoni @ valeo.ai
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import os
17
+ import torch
18
+ import torch.nn as nn
19
+ import dino.vision_transformer as vits
20
+
21
+ from bkg_seg import compute_img_bkg_seg
22
+ from misc import batch_apply_bilateral_solver
23
+
24
+ class FoundModel(nn.Module):
25
+ def __init__(
26
+ self,
27
+ vit_model="dino",
28
+ vit_arch="vit_small",
29
+ vit_patch_size=8,
30
+ enc_type_feats="k",
31
+ bkg_type_feats="k",
32
+ bkg_th=0.3
33
+ ):
34
+
35
+ super(FoundModel, self).__init__()
36
+
37
+ # ----------------------
38
+ # Encoder
39
+ self.vit_encoder, self.initial_dim, self.hook_features = get_vit_encoder(
40
+ vit_arch, vit_model, vit_patch_size, enc_type_feats
41
+ )
42
+ self.vit_patch_size = vit_patch_size
43
+ self.enc_type_feats = enc_type_feats
44
+
45
+ # ----------------------
46
+ # Background Segmentation
47
+ self.bkg_type_feats = bkg_type_feats
48
+ self.bkg_th = bkg_th
49
+
50
+ # ----------------------
51
+ # Define the simple decoder
52
+ self.previous_dim = self.initial_dim
53
+ self.decoder = nn.Conv2d(self.previous_dim, 1, (1, 1))
54
+
55
+ def forward_step(self, batch, decoder=None, for_eval=False):
56
+
57
+ # Make the image divisible by the patch size
58
+ if for_eval:
59
+ batch = self.make_input_divisible(batch)
60
+ _w, _h = batch.shape[-2:]
61
+ _h, _w = _h // self.vit_patch_size, _w // self.vit_patch_size
62
+ else:
63
+ # Cropping used during training, could be changed to improve
64
+ w, h = (
65
+ batch.shape[-2] - batch.shape[-2] % self.vit_patch_size,
66
+ batch.shape[-1] - batch.shape[-1] % self.vit_patch_size,
67
+ )
68
+ batch = batch[:, :, :w, :h]
69
+
70
+ w_featmap = batch.shape[-2] // self.vit_patch_size
71
+ h_featmap = batch.shape[-1] // self.vit_patch_size
72
+
73
+ # Forward pass
74
+ with torch.no_grad():
75
+ # Encoder forward pass
76
+ att = self.vit_encoder.get_last_selfattention(batch)
77
+
78
+ # Get decoder features
79
+ feats = self.extract_feats(dims=att.shape, type_feats=self.enc_type_feats)
80
+ feats = feats[:, 1:, :, :].reshape(att.shape[0], w_featmap, h_featmap, -1)
81
+ feats = feats.permute(0, 3, 1, 2)
82
+
83
+ # Apply decoder
84
+ if decoder is None:
85
+ decoder = self.decoder
86
+ preds = decoder(feats)
87
+
88
+ # return preds_masked
89
+ return preds, feats, (w_featmap, h_featmap), att
90
+
91
+ def make_input_divisible(self, x: torch.Tensor) -> torch.Tensor:
92
+ # From selfmask
93
+ """Pad some pixels to make the input size divisible by the patch size."""
94
+ B, _, H_0, W_0 = x.shape
95
+ pad_w = (self.vit_patch_size - W_0 % self.vit_patch_size) % self.vit_patch_size
96
+ pad_h = (self.vit_patch_size - H_0 % self.vit_patch_size) % self.vit_patch_size
97
+
98
+ x = nn.functional.pad(x, (0, pad_w, 0, pad_h), value=0)
99
+ return x
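+
+ # Illustrative example: with vit_patch_size = 8 and a 253 x 500 input,
+ # pad_h = (8 - 253 % 8) % 8 = 3 and pad_w = (8 - 500 % 8) % 8 = 4, giving a
+ # 256 x 504 tensor that splits evenly into 8 x 8 patches.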
100
+
101
+     def compute_background_batch(
+         self,
+         att,
+         shape_f,
+         # mlp_feats=None,
+     ):
+
+         w_f, h_f = shape_f
+
+         # Dimensions
+         nb_im = att.shape[0]      # Batch size
+         nh = att.shape[1]         # Number of heads
+         nb_tokens = att.shape[2]  # Number of tokens
+
+         # Get the features used for the background detection
+         feats = self.extract_feats(
+             dims=att.shape,
+             # mlp_feats=mlp_feats,
+             type_feats=self.bkg_type_feats,
+         )
+         feats = feats.reshape(nb_im, nb_tokens, -1)
+
+         bkg_mask = compute_img_bkg_seg(
+             att,
+             feats,
+             (w_f, h_f),
+             th_bkg=self.bkg_th,
+             dim=int(self.initial_dim / nh),
+         )
+
+         return bkg_mask
+
+
+     def get_bkg_pseudo_labels_batch(
+         self,
+         att,
+         shape_f,
+         data,
+         use_bilateral_solver=True,
+         shape=None,
+     ):
+
+         bkg_mask_pred = self.compute_background_batch(att, shape_f)
+
+         # Transform bkg detection to foreground detection
+         # Object mask is the inverse of the bkg mask
+         obj_mask = (~bkg_mask_pred.bool()).float()
+
+         if use_bilateral_solver:
+             pseudo_labels, cnt_bs = batch_apply_bilateral_solver(data, obj_mask, shape)
+             return pseudo_labels, cnt_bs
+         else:
+             return obj_mask, 0
+
+     @torch.no_grad()
+     def decoder_load_weights(self, weights_path):
+         print(f"Loading model from weights {weights_path}.")
+         # Load states
+         state_dict = torch.load(weights_path)
+
+         # Decoder
+         self.decoder.load_state_dict(state_dict["decoder"])
+         self.decoder.eval()
+         self.decoder.to("cuda")
+
+     @torch.no_grad()
+     def decoder_save_weights(self, save_dir, n_iter):
+         state_dict = {}
+         state_dict["decoder"] = self.decoder.state_dict()
+         fname = os.path.join(save_dir, f"decoder_weights_niter{n_iter}.pt")
+         torch.save(state_dict, fname)
+         print(f"\n----\nModel saved at {fname}")
+
+     @torch.no_grad()
+     def extract_feats(self, dims, type_feats="k"):
+
+         nb_im, nh, nb_tokens, _ = dims
+         qkv = (
+             self.hook_features["qkv"]
+             .reshape(
+                 nb_im, nb_tokens, 3, nh, -1  # 3 corresponding to |qkv|; the head dimension is inferred
+             )
+             .permute(2, 0, 3, 1, 4)
+         )
+
+         q, k, v = qkv[0], qkv[1], qkv[2]
+
+         if type_feats == "q":
+             return q.transpose(1, 2).float()
+         elif type_feats == "k":
+             return k.transpose(1, 2).float()
+         elif type_feats == "v":
+             return v.transpose(1, 2).float()
+         else:
+             raise ValueError("Unknown features")
+
+
+ def get_vit_encoder(vit_arch, vit_model, vit_patch_size, enc_type_feats):
+     if vit_arch == "vit_small" and vit_patch_size == 16:
+         url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
+         initial_dim = 384
+     elif vit_arch == "vit_small" and vit_patch_size == 8:
+         url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"
+         initial_dim = 384
+     elif vit_arch == "vit_base" and vit_patch_size == 16:
+         if vit_model == "clip":
+             url = "5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"
+         elif vit_model == "dino":
+             url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
+         initial_dim = 768
+     elif vit_arch == "vit_base" and vit_patch_size == 8:
+         url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
+         initial_dim = 768
+     else:
+         raise ValueError(f"Unknown ViT configuration: {vit_arch} with patch size {vit_patch_size}")
+
+     if vit_model == "dino":
+         vit_encoder = vits.__dict__[vit_arch](patch_size=vit_patch_size, num_classes=0)
+         # TODO: change here to leave the last layer unfrozen
+         for p in vit_encoder.parameters():
+             p.requires_grad = False
+         vit_encoder.eval().cuda()  # eval mode
+         state_dict = torch.hub.load_state_dict_from_url(
+             url="https://dl.fbaipublicfiles.com/dino/" + url
+         )
+         vit_encoder.load_state_dict(state_dict, strict=True)
+     else:
+         # Only the DINO backbone loading is implemented in this file
+         raise ValueError(f"Unsupported vit_model: {vit_model}")
+
+     hook_features = {}
+     if enc_type_feats in ["k", "q", "v", "qkv", "mlp"]:
+         # Define the hook storing the output of the qkv projection of the last block
+         def hook_fn_forward_qkv(module, input, output):
+             hook_features["qkv"] = output
+
+         vit_encoder._modules["blocks"][-1]._modules["attn"]._modules[
+             "qkv"
+         ].register_forward_hook(hook_fn_forward_qkv)
+     else:
+         raise ValueError("Not implemented.")
+
+     return vit_encoder, initial_dim, hook_features
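
For readers unfamiliar with the hook mechanism used in get_vit_encoder, the short, self-contained sketch below (not part of the commit) shows the same pattern on a plain nn.Linear: a forward hook stores the layer's raw output in a dictionary so it can be reused after the forward pass, which is exactly how hook_features["qkv"] gets filled above. The layer sizes and token count are illustrative assumptions.

import torch
import torch.nn as nn

captured = {}

def hook_fn_forward_qkv(module, inputs, output):
    # Same idea as hook_features["qkv"] above: stash the raw layer output
    captured["qkv"] = output

layer = nn.Linear(384, 3 * 384)    # stands in for the ViT qkv projection
handle = layer.register_forward_hook(hook_fn_forward_qkv)

tokens = torch.rand(2, 197, 384)   # (batch, 1 CLS + 196 patch tokens, embed dim)
_ = layer(tokens)                  # the forward pass triggers the hook
print(captured["qkv"].shape)       # torch.Size([2, 197, 1152])
handle.remove()
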
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ pyyaml
+ matplotlib==3.5.2
+ numpy==1.21.4
+ opencv-python==4.5.5.64
+ opencv-python-headless==4.5.5.64
+ scipy==1.7.3
+ tensorboard
+ tqdm==4.64.0
+ pycocotools==2.0.4
+ Pillow==9.1.1
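
Putting the two added files together, a minimal inference pass might look like the sketch below. This is an illustration, not code from the commit: it assumes the model file above is saved as model.py, that a CUDA device is available, that the DINO weights can be downloaded, and that a trained decoder checkpoint exists at the (hypothetical) path shown.

import torch
from model import FoundModel   # assumes the file above is saved as model.py

model = FoundModel(vit_arch="vit_small", vit_patch_size=8)
model.decoder_load_weights("outputs/decoder_weights_niter500.pt")  # hypothetical checkpoint path

img = torch.rand(1, 3, 224, 224).cuda()  # stand-in for a normalized RGB image batch
with torch.no_grad():
    preds, feats, (w_f, h_f), att = model.forward_step(img, for_eval=True)

# preds is a (1, 1, w_f, h_f) map of foreground logits over ViT patches;
# upsample to the image resolution and threshold to obtain a binary mask.
mask = torch.nn.functional.interpolate(
    preds, size=img.shape[-2:], mode="bilinear", align_corners=False
).sigmoid() > 0.5

The sigmoid-and-threshold step here is only one possible post-processing; the repository may instead refine predictions with the bilateral solver imported from misc.
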