FOUND - second
- LICENSE +201 -0
- __init__.py +3 -0
- app.py +84 -6
- bilateral_solver.py +214 -0
- bkg_seg.py +84 -0
- configs/found_DUTS-TR.yaml +34 -0
- data/examples/VOC_000030.jpg +0 -0
- data/weights/decoder_weights.pt +0 -0
- datasets/VOC.py +80 -0
- datasets/__init__.py +0 -0
- datasets/augmentations.py +68 -0
- datasets/datasets.py +409 -0
- datasets/geometric_transforms.py +160 -0
- datasets/uod_datasets.py +384 -0
- datasets/utils.py +44 -0
- evaluation/__init__.py +0 -0
- evaluation/metrics/__init__.py +0 -0
- evaluation/metrics/average_meter.py +21 -0
- evaluation/metrics/f_measure.py +111 -0
- evaluation/metrics/iou.py +37 -0
- evaluation/metrics/mae.py +14 -0
- evaluation/metrics/pixel_acc.py +21 -0
- evaluation/metrics/s_measure.py +126 -0
- evaluation/saliency.py +290 -0
- evaluation/uod.py +118 -0
- main_found_evaluate.py +122 -0
- main_visualize.py +99 -0
- misc.py +254 -0
- model.py +243 -0
- requirements.txt +10 -0
LICENSE
ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

   "License" shall mean the terms and conditions for use, reproduction,
   and distribution as defined by Sections 1 through 9 of this document.

   "Licensor" shall mean the copyright owner or entity authorized by
   the copyright owner that is granting the License.

   "Legal Entity" shall mean the union of the acting entity and all
   other entities that control, are controlled by, or are under common
   control with that entity. For the purposes of this definition,
   "control" means (i) the power, direct or indirect, to cause the
   direction or management of such entity, whether by contract or
   otherwise, or (ii) ownership of fifty percent (50%) or more of the
   outstanding shares, or (iii) beneficial ownership of such entity.

   "You" (or "Your") shall mean an individual or Legal Entity
   exercising permissions granted by this License.

   "Source" form shall mean the preferred form for making modifications,
   including but not limited to software source code, documentation
   source, and configuration files.

   "Object" form shall mean any form resulting from mechanical
   transformation or translation of a Source form, including but
   not limited to compiled object code, generated documentation,
   and conversions to other media types.

   "Work" shall mean the work of authorship, whether in Source or
   Object form, made available under the License, as indicated by a
   copyright notice that is included in or attached to the work
   (an example is provided in the Appendix below).

   "Derivative Works" shall mean any work, whether in Source or Object
   form, that is based on (or derived from) the Work and for which the
   editorial revisions, annotations, elaborations, or other modifications
   represent, as a whole, an original work of authorship. For the purposes
   of this License, Derivative Works shall not include works that remain
   separable from, or merely link (or bind by name) to the interfaces of,
   the Work and Derivative Works thereof.

   "Contribution" shall mean any work of authorship, including
   the original version of the Work and any modifications or additions
   to that Work or Derivative Works thereof, that is intentionally
   submitted to Licensor for inclusion in the Work by the copyright owner
   or by an individual or Legal Entity authorized to submit on behalf of
   the copyright owner. For the purposes of this definition, "submitted"
   means any form of electronic, verbal, or written communication sent
   to the Licensor or its representatives, including but not limited to
   communication on electronic mailing lists, source code control systems,
   and issue tracking systems that are managed by, or on behalf of, the
   Licensor for the purpose of discussing and improving the Work, but
   excluding communication that is conspicuously marked or otherwise
   designated in writing by the copyright owner as "Not a Contribution."

   "Contributor" shall mean Licensor and any individual or Legal Entity
   on behalf of whom a Contribution has been received by Licensor and
   subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
   this License, each Contributor hereby grants to You a perpetual,
   worldwide, non-exclusive, no-charge, royalty-free, irrevocable
   copyright license to reproduce, prepare Derivative Works of,
   publicly display, publicly perform, sublicense, and distribute the
   Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
   this License, each Contributor hereby grants to You a perpetual,
   worldwide, non-exclusive, no-charge, royalty-free, irrevocable
   (except as stated in this section) patent license to make, have made,
   use, offer to sell, sell, import, and otherwise transfer the Work,
   where such license applies only to those patent claims licensable
   by such Contributor that are necessarily infringed by their
   Contribution(s) alone or by combination of their Contribution(s)
   with the Work to which such Contribution(s) was submitted. If You
   institute patent litigation against any entity (including a
   cross-claim or counterclaim in a lawsuit) alleging that the Work
   or a Contribution incorporated within the Work constitutes direct
   or contributory patent infringement, then any patent licenses
   granted to You under this License for that Work shall terminate
   as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
   Work or Derivative Works thereof in any medium, with or without
   modifications, and in Source or Object form, provided that You
   meet the following conditions:

   (a) You must give any other recipients of the Work or
       Derivative Works a copy of this License; and

   (b) You must cause any modified files to carry prominent notices
       stating that You changed the files; and

   (c) You must retain, in the Source form of any Derivative Works
       that You distribute, all copyright, patent, trademark, and
       attribution notices from the Source form of the Work,
       excluding those notices that do not pertain to any part of
       the Derivative Works; and

   (d) If the Work includes a "NOTICE" text file as part of its
       distribution, then any Derivative Works that You distribute must
       include a readable copy of the attribution notices contained
       within such NOTICE file, excluding those notices that do not
       pertain to any part of the Derivative Works, in at least one
       of the following places: within a NOTICE text file distributed
       as part of the Derivative Works; within the Source form or
       documentation, if provided along with the Derivative Works; or,
       within a display generated by the Derivative Works, if and
       wherever such third-party notices normally appear. The contents
       of the NOTICE file are for informational purposes only and
       do not modify the License. You may add Your own attribution
       notices within Derivative Works that You distribute, alongside
       or as an addendum to the NOTICE text from the Work, provided
       that such additional attribution notices cannot be construed
       as modifying the License.

   You may add Your own copyright statement to Your modifications and
   may provide additional or different license terms and conditions
   for use, reproduction, or distribution of Your modifications, or
   for any such Derivative Works as a whole, provided Your use,
   reproduction, and distribution of the Work otherwise complies with
   the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
   any Contribution intentionally submitted for inclusion in the Work
   by You to the Licensor shall be under the terms and conditions of
   this License, without any additional terms or conditions.
   Notwithstanding the above, nothing herein shall supersede or modify
   the terms of any separate license agreement you may have executed
   with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
   names, trademarks, service marks, or product names of the Licensor,
   except as required for reasonable and customary use in describing the
   origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
   agreed to in writing, Licensor provides the Work (and each
   Contributor provides its Contributions) on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
   implied, including, without limitation, any warranties or conditions
   of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
   PARTICULAR PURPOSE. You are solely responsible for determining the
   appropriateness of using or redistributing the Work and assume any
   risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
   whether in tort (including negligence), contract, or otherwise,
   unless required by applicable law (such as deliberate and grossly
   negligent acts) or agreed to in writing, shall any Contributor be
   liable to You for damages, including any direct, indirect, special,
   incidental, or consequential damages of any character arising as a
   result of this License or out of the use or inability to use the
   Work (including but not limited to damages for loss of goodwill,
   work stoppage, computer failure or malfunction, or any and all
   other commercial damages or losses), even if such Contributor
   has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
   the Work or Derivative Works thereof, You may choose to offer,
   and charge a fee for, acceptance of support, warranty, indemnity,
   or other liability obligations and/or rights consistent with this
   License. However, in accepting such obligations, You may act only
   on Your own behalf and on Your sole responsibility, not on behalf
   of any other Contributor, and only if You agree to indemnify,
   defend, and hold each Contributor harmless for any liability
   incurred by, or claims asserted against, such Contributor by reason
   of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

   To apply the Apache License to your work, attach the following
   boilerplate notice, with the fields enclosed by brackets "[]"
   replaced with your own identifying information. (Don't include
   the brackets!) The text should be enclosed in the appropriate
   comment syntax for the file format. We also recommend that a
   file or class name and description of purpose be included on the
   same "printed page" as the copyright notice for easier
   identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
__init__.py
ADDED
@@ -0,0 +1,3 @@
import sys
from os.path import dirname, join
sys.path.insert(0, join(dirname(__file__), '.'))
app.py
CHANGED
@@ -1,16 +1,94 @@
+import os
+import torch
+import argparse
+import torch.nn as nn
+import torch.nn.functional as F
+import matplotlib.pyplot as plt
+
+from PIL import Image
+from model import FoundModel
+from misc import load_config
+from torchvision import transforms as T
+
+
 import gradio as gr
+
+NORMALIZE = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
+CACHE = True
+
+def blend_images(bg, fg, alpha=0.5):
+    fg = fg.convert('RGBA')
+    bg = bg.convert('RGBA')
+    blended = Image.blend(bg, fg, alpha=alpha)
+
+    return blended
+
+
+def predict(img_input):
+
+    config = "configs/found_DUTS-TR.yaml"
+    model_weights = "data/weights/decoder_weights.pt"
+
+    # Configuration
+    config = load_config(config)
+
+    # ------------------------------------
+    # Load the model
+    model = FoundModel(vit_model=config.model["pre_training"],
+                       vit_arch=config.model["arch"],
+                       vit_patch_size=config.model["patch_size"],
+                       enc_type_feats=config.found["feats"],
+                       bkg_type_feats=config.found["feats"],
+                       bkg_th=config.found["bkg_th"])
+    # Load weights
+    model.decoder_load_weights(model_weights)
+    model.eval()
+    print(f"Model {model_weights} loaded correctly.")
+
+    # Load the image
+    img_pil = Image.open(img_input)
+    img = img_pil.convert("RGB")
+
+    t = T.Compose([T.ToTensor(), NORMALIZE])
+    img_t = t(img)[None,:,:,:]
+    inputs = img_t.to("cuda")
+
+    # Forward step
+    with torch.no_grad():
+        preds, _, _, _ = model.forward_step(inputs, for_eval=True)
+
+    # Apply FOUND
+    sigmoid = nn.Sigmoid()
+    h, w = img_t.shape[-2:]
+    preds_up = F.interpolate(
+        preds, scale_factor=model.vit_patch_size, mode="bilinear", align_corners=False
+    )[..., :h, :w]
+    preds_up = (
+        (sigmoid(preds_up.detach()) > 0.5).squeeze(0).float()
+    )
+
+    return blend_images(img_pil, preds_up)
+
 
 title = 'FOUND'
 description = 'Gradio Demo accompanying paper "Unsupervised Object Localization: Observing the Background to Discover Objects"\n \
              The app is running CPU-only, times are therefore .\n'
-article = """<
+article = """<h2 align="center">Unsupervised Object Localization: Observing the Background to Discover Objects </h2>
+<h1 align="center"> FOUND </h1>
 """
+examples = ["data/examples/VOC_000030.jpg"]
+
 
-
-
-iface
-             article=article, inputs="text", outputs="text")
-iface.launch()
+iface = gr.Interface(fn=predict,
+                     title=title,
+                     description=description,
+                     article=article,
+                     inputs=gr.Image(type='filepath'),
+                     outputs=gr.Image(label="Object localization", type="pil"),
+                     examples=examples,
+                     cache_examples=CACHE
+                     )
 
+iface.launch(show_error=True, enable_queue=True, inline=True)
 
 
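Note that the new predict() moves the input to "cuda" even though the description says the Space runs CPU-only. A device-agnostic variant would look like the sketch below; it is not part of the commit and only illustrates the intended flow, reusing the model and img_t built exactly as in predict() above.

import torch

def pick_device() -> torch.device:
    # Sketch only: fall back to CPU when no GPU is available (the Space is CPU-only).
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Inside predict(), instead of hard-coding img_t.to("cuda"):
#   device = pick_device()
#   model = model.to(device)
#   inputs = img_t.to(device)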
bilateral_solver.py
ADDED
@@ -0,0 +1,214 @@
"""
Code adapted from TokenCut: https://github.com/YangtaoWANG95/TokenCut
"""

import PIL.Image as Image
import numpy as np
from scipy import ndimage
from scipy.sparse import diags, csr_matrix
from scipy.sparse.linalg import cg

RGB_TO_YUV = np.array(
    [[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]]
)
YUV_TO_RGB = np.array([[1.0, 0.0, 1.402], [1.0, -0.34414, -0.71414], [1.0, 1.772, 0.0]])
YUV_OFFSET = np.array([0, 128.0, 128.0]).reshape(1, 1, -1)
MAX_VAL = 255.0


def rgb2yuv(im):
    return np.tensordot(im, RGB_TO_YUV, ([2], [1])) + YUV_OFFSET


def yuv2rgb(im):
    return np.tensordot(im.astype(float) - YUV_OFFSET, YUV_TO_RGB, ([2], [1]))


def get_valid_idx(valid, candidates):
    """Find which values are present in a list and where they are located"""
    locs = np.searchsorted(valid, candidates)
    # Handle edge case where the candidate is larger than all valid values
    locs = np.clip(locs, 0, len(valid) - 1)
    # Identify which values are actually present
    valid_idx = np.flatnonzero(valid[locs] == candidates)
    locs = locs[valid_idx]
    return valid_idx, locs


class BilateralGrid(object):
    def __init__(self, im, sigma_spatial=32, sigma_luma=8, sigma_chroma=8):
        im_yuv = rgb2yuv(im)
        # Compute 5-dimensional XYLUV bilateral-space coordinates
        Iy, Ix = np.mgrid[: im.shape[0], : im.shape[1]]
        x_coords = (Ix / sigma_spatial).astype(int)
        y_coords = (Iy / sigma_spatial).astype(int)
        luma_coords = (im_yuv[..., 0] / sigma_luma).astype(int)
        chroma_coords = (im_yuv[..., 1:] / sigma_chroma).astype(int)
        coords = np.dstack((x_coords, y_coords, luma_coords, chroma_coords))
        coords_flat = coords.reshape(-1, coords.shape[-1])
        self.npixels, self.dim = coords_flat.shape
        # Hacky "hash vector" for coordinates,
        # Requires all scaled coordinates be < MAX_VAL
        self.hash_vec = MAX_VAL ** np.arange(self.dim)
        # Construct S and B matrix
        self._compute_factorization(coords_flat)

    def _compute_factorization(self, coords_flat):
        # Hash each coordinate in grid to a unique value
        hashed_coords = self._hash_coords(coords_flat)
        unique_hashes, unique_idx, idx = np.unique(
            hashed_coords, return_index=True, return_inverse=True
        )
        # Identify unique set of vertices
        unique_coords = coords_flat[unique_idx]
        self.nvertices = len(unique_coords)
        # Construct sparse splat matrix that maps from pixels to vertices
        self.S = csr_matrix((np.ones(self.npixels), (idx, np.arange(self.npixels))))
        # Construct sparse blur matrices.
        # Note that these represent [1 0 1] blurs, excluding the central element
        self.blurs = []
        for d in range(self.dim):
            blur = 0.0
            for offset in (-1, 1):
                offset_vec = np.zeros((1, self.dim))
                offset_vec[:, d] = offset
                neighbor_hash = self._hash_coords(unique_coords + offset_vec)
                valid_coord, idx = get_valid_idx(unique_hashes, neighbor_hash)
                blur = blur + csr_matrix(
                    (np.ones((len(valid_coord),)), (valid_coord, idx)),
                    shape=(self.nvertices, self.nvertices),
                )
            self.blurs.append(blur)

    def _hash_coords(self, coord):
        """Hacky function to turn a coordinate into a unique value"""
        return np.dot(coord.reshape(-1, self.dim), self.hash_vec)

    def splat(self, x):
        return self.S.dot(x)

    def slice(self, y):
        return self.S.T.dot(y)

    def blur(self, x):
        """Blur a bilateral-space vector with a 1 2 1 kernel in each dimension"""
        assert x.shape[0] == self.nvertices
        out = 2 * self.dim * x
        for blur in self.blurs:
            out = out + blur.dot(x)
        return out

    def filter(self, x):
        """Apply bilateral filter to an input x"""
        return self.slice(self.blur(self.splat(x))) / self.slice(
            self.blur(self.splat(np.ones_like(x)))
        )


def bistochastize(grid, maxiter=10):
    """Compute diagonal matrices to bistochastize a bilateral grid"""
    m = grid.splat(np.ones(grid.npixels))
    n = np.ones(grid.nvertices)
    for i in range(maxiter):
        n = np.sqrt(n * m / grid.blur(n))
    # Correct m to satisfy the assumption of bistochastization regardless
    # of how many iterations have been run.
    m = n * grid.blur(n)
    Dm = diags(m, 0)
    Dn = diags(n, 0)
    return Dn, Dm


class BilateralSolver(object):
    def __init__(self, grid, params):
        self.grid = grid
        self.params = params
        self.Dn, self.Dm = bistochastize(grid)

    def solve(self, x, w):
        # Check that w is a vector or a nx1 matrix
        if w.ndim == 2:
            assert w.shape[1] == 1
        elif w.ndim == 1:
            w = w.reshape(w.shape[0], 1)
        A_smooth = self.Dm - self.Dn.dot(self.grid.blur(self.Dn))
        w_splat = self.grid.splat(w)
        A_data = diags(w_splat[:, 0], 0)
        A = self.params["lam"] * A_smooth + A_data
        xw = x * w
        b = self.grid.splat(xw)
        # Use simple Jacobi preconditioner
        A_diag = np.maximum(A.diagonal(), self.params["A_diag_min"])
        M = diags(1 / A_diag, 0)
        # Flat initialization
        y0 = self.grid.splat(xw) / w_splat
        yhat = np.empty_like(y0)
        for d in range(x.shape[-1]):
            yhat[..., d], info = cg(
                A,
                b[..., d],
                x0=y0[..., d],
                M=M,
                maxiter=self.params["cg_maxiter"],
                tol=self.params["cg_tol"],
            )
        xhat = self.grid.slice(yhat)
        return xhat


def bilateral_solver_output(
    img_pth,
    target,
    img=None,
    sigma_spatial=24,
    sigma_luma=4,
    sigma_chroma=4,
    get_all_cc=False
):
    if img is None:
        reference = np.array(Image.open(img_pth).convert("RGB"))
    else:
        reference = np.array(img)

    h, w = target.shape
    confidence = np.ones((h, w)) * 0.999

    grid_params = {
        "sigma_luma": sigma_luma,  # Brightness bandwidth
        "sigma_chroma": sigma_chroma,  # Color bandwidth
        "sigma_spatial": sigma_spatial,  # Spatial bandwidth
    }

    bs_params = {
        "lam": 256,  # The strength of the smoothness parameter
        "A_diag_min": 1e-5,  # Clamp the diagonal of A in the Jacobi preconditioner.
        "cg_tol": 1e-5,  # The tolerance on the convergence in PCG
        "cg_maxiter": 25,  # The number of PCG iterations
    }

    grid = BilateralGrid(reference, **grid_params)

    t = target.reshape(-1, 1).astype(np.double)
    c = confidence.reshape(-1, 1).astype(np.double)

    ## output of the solver, which is a soft value
    output_solver = BilateralSolver(grid, bs_params).solve(t, c).reshape((h, w))

    binary_solver = ndimage.binary_fill_holes(output_solver > 0.5)
    labeled, nr_objects = ndimage.label(binary_solver)

    nb_pixel = [np.sum(labeled == i) for i in range(nr_objects + 1)]
    pixel_order = np.argsort(nb_pixel)

    if get_all_cc:
        # Remove known background
        pixel_descending_order = pixel_order[::-1]
        # Get all CC except the biggest one, which may be considered background; try and change here
        binary_solver = (labeled[None,:,:] == pixel_descending_order[1:,None,None]).astype(int).sum(0)
    else:
        try:
            binary_solver = labeled == pixel_order[-2]
        except:
            binary_solver = np.ones((h, w), dtype=bool)

    return output_solver, binary_solver
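For reference, a minimal way to exercise bilateral_solver_output on a synthetic image and coarse mask (a sketch only; the image and mask values below are made up, whereas real callers in this repo pass a saliency prediction as target):

import numpy as np
from bilateral_solver import bilateral_solver_output

# Synthetic 64x64 RGB image with a bright square, plus a rough float mask of it.
img = np.zeros((64, 64, 3), dtype=np.uint8)
img[16:48, 16:48] = 200
coarse_mask = np.zeros((64, 64), dtype=float)
coarse_mask[20:44, 20:44] = 1.0

# soft_mask is the smoothed float output, refined_mask the post-processed binary one.
soft_mask, refined_mask = bilateral_solver_output(
    img_pth=None, target=coarse_mask, img=img,
    sigma_spatial=24, sigma_luma=4, sigma_chroma=4,
)
print(soft_mask.shape, refined_mask.shape)  # (64, 64) (64, 64)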
bkg_seg.py
ADDED
@@ -0,0 +1,84 @@
# Copyright 2022 - Valeo Comfort and Driving Assistance - Oriane Siméoni @ valeo.ai
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn.functional as F

from typing import Tuple

def compute_img_bkg_seg(
    attentions,
    feats,
    featmap_dims,
    th_bkg,
    dim=64,
    epsilon: float = 1e-10,
    apply_weights: bool = True,
) -> Tuple[torch.Tensor, float]:
    """
    inputs
     - attentions [B, ]
    """

    w_featmap, h_featmap = featmap_dims

    nb, nh, _ = attentions.shape[:3]
    # we keep only the output patch attention
    att = attentions[:, :, 0, 1:].reshape(nb, nh, -1)
    att = att.reshape(nb, nh, w_featmap, h_featmap)

    # -----------------------------------------------
    # Inspired by the CroW sparsity-based channel weighting of each head (CroW, Kalantidis et al.)
    threshold = torch.mean(att.reshape(nb, -1), dim=1)  # Find threshold per image
    Q = torch.sum(
        att.reshape(nb, nh, w_featmap * h_featmap) > threshold[:, None, None], axis=2
    ) / (w_featmap * h_featmap)
    beta = torch.log(torch.sum(Q + epsilon, dim=1)[:, None] / (Q + epsilon))

    # Weight features based on attention sparsity
    descs = feats[:, 1:, ]
    if apply_weights:
        descs = (descs.reshape(nb, -1, nh, dim) * beta[:, None, :, None]).reshape(
            nb, -1, nh * dim
        )
    else:
        descs = (descs.reshape(nb, -1, nh, dim)).reshape(
            nb, -1, nh * dim
        )

    # -----------------------------------------------
    # Compute cosine-similarities
    descs = F.normalize(descs, dim=-1, p=2)
    cos_sim = torch.bmm(descs, descs.permute(0, 2, 1))

    # -----------------------------------------------
    # Find pixel with least amount of attention
    if apply_weights:
        att = att.reshape(nb, nh, w_featmap, h_featmap) * beta[:, :, None, None]
    else:
        att = att.reshape(nb, nh, w_featmap, h_featmap)
    id_pixel_ref = torch.argmin(torch.sum(att, axis=1).reshape(nb, -1), dim=-1)

    # -----------------------------------------------
    # Mask of definitely background pixels: 1 on the background
    cos_sim = cos_sim.reshape(nb, -1, w_featmap * h_featmap)

    bkg_mask = (
        cos_sim[torch.arange(cos_sim.size(0)), id_pixel_ref, :].reshape(
            nb, w_featmap, h_featmap
        )
        > th_bkg
    )  # mask to be used to remove background

    return bkg_mask.float()
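compute_img_bkg_seg expects the self-attention maps and token features of a ViT backbone: attentions of shape [B, heads, tokens, tokens] and feats of shape [B, tokens, heads*dim], with one CLS token prepended to the patch tokens. The shape-only sketch below uses random tensors and assumes a ViT-S/8 on a 224x224 image (6 heads, 64 dims per head, 28x28 patches); these numbers are illustrative, not taken from the commit.

import torch
from bkg_seg import compute_img_bkg_seg

B, nh, dim, wf, hf = 2, 6, 64, 28, 28
ntok = wf * hf + 1  # patch tokens plus the CLS token
attentions = torch.rand(B, nh, ntok, ntok)
feats = torch.rand(B, ntok, nh * dim)

bkg_mask = compute_img_bkg_seg(attentions, feats, (wf, hf), th_bkg=0.3, dim=dim)
print(bkg_mask.shape)  # torch.Size([2, 28, 28]); 1.0 marks likely-background patches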
configs/found_DUTS-TR.yaml
ADDED
@@ -0,0 +1,34 @@
model:
  arch: vit_small
  patch_size: 8
  pre_training: dino

found:
  bkg_th: 0.3
  feats: k

training:
  dataset: DUTS-TR
  dataset_set: null

  # Hyper params
  seed: 0
  max_iter: 500
  nb_epochs: 3
  batch_size: 50
  lr0: 5e-2
  step_lr_size: 50
  step_lr_gamma: 0.95
  w_bs_loss: 1.5
  stop_bkg_loss: 100

  # Augmentations
  crop_size: 224
  scale_range: [0.1, 3.0]
  photometric_aug: gaussian_blur
  proba_photometric_aug: 0.5
  cropping_strategy: random_scale

evaluation:
  type: saliency # uod, retrieval
  datasets: [DUT-OMRON, ECSSD]
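The repo consumes this file through misc.load_config (see app.py above); being plain YAML, it can also be inspected directly. A small sketch using PyYAML, which is an assumption here rather than a dependency stated in the commit:

import yaml

with open("configs/found_DUTS-TR.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["model"]["arch"], cfg["model"]["patch_size"])  # vit_small 8
print(cfg["found"]["bkg_th"], cfg["found"]["feats"])     # 0.3 k
print(cfg["training"]["cropping_strategy"])              # random_scale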
data/examples/VOC_000030.jpg
ADDED
data/weights/decoder_weights.pt
ADDED
Binary file (2.69 kB)
datasets/VOC.py
ADDED
@@ -0,0 +1,80 @@
import os
from typing import Optional, Tuple, Union, Dict, List

import cv2
from pycocotools.coco import COCO
import numpy as np
import torch
import torchvision
from PIL import Image, PngImagePlugin
from torch.utils.data import Dataset
from torchvision import transforms as T
from torchvision.transforms import ColorJitter, RandomApply, RandomGrayscale
from tqdm import tqdm

VOCDetectionMetadataType = Dict[str, Dict[str, Union[str, Dict[str, str], List[str]]]]

def get_voc_detection_gt(
    metadata: VOCDetectionMetadataType, remove_hards: bool = False
) -> Tuple[np.array, List[str]]:
    objects = metadata["annotation"]["object"]
    nb_obj = len(objects)

    gt_bbxs = []
    gt_clss = []
    for object in range(nb_obj):
        if remove_hards and (
            objects[object]["truncated"] == "1"
            or objects[object]["difficult"] == "1"
        ):
            continue

        gt_cls = objects[object]["name"]
        gt_clss.append(gt_cls)
        obj = objects[object]["bndbox"]
        x1y1x2y2 = [
            int(obj["xmin"]),
            int(obj["ymin"]),
            int(obj["xmax"]),
            int(obj["ymax"]),
        ]

        # Original annotations are integers in the range [1, W or H]
        # Assuming they mean 1-based pixel indices (inclusive),
        # a box with annotation (xmin=1, xmax=W) covers the whole image.
        # In coordinate space this is represented by (xmin=0, xmax=W)
        x1y1x2y2[0] -= 1
        x1y1x2y2[1] -= 1
        gt_bbxs.append(x1y1x2y2)

    return np.asarray(gt_bbxs), gt_clss

def create_gt_masks_if_voc(labels: PngImagePlugin.PngImageFile) -> Image.Image:
    mask = np.array(labels)
    mask_gt = (mask > 0).astype(float)
    mask_gt = np.where(mask_gt != 0.0, 255, mask_gt)
    mask_gt = Image.fromarray(np.uint8(mask_gt))
    return mask_gt

def create_VOC_loader(img_dir, dataset_set, evaluation_type):
    year = img_dir[-4:]
    download = not os.path.exists(img_dir)
    if evaluation_type == "uod":
        loader = torchvision.datasets.VOCDetection(
            img_dir,
            year=year,
            image_set=dataset_set,
            transform=None,
            download=download,
        )
    elif evaluation_type == "saliency":
        loader = torchvision.datasets.VOCSegmentation(
            img_dir,
            year=year,
            image_set=dataset_set,
            transform=None,
            download=download,
        )
    else:
        raise ValueError(f"Not implemented for {evaluation_type}.")
    return loader
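get_voc_detection_gt expects the target dictionary produced by torchvision.datasets.VOCDetection. A sketch with a hand-built single-object annotation (the class name and box values below are made up for illustration):

import numpy as np
from datasets.VOC import get_voc_detection_gt

# Minimal VOCDetection-style target with one annotated object.
metadata = {
    "annotation": {
        "object": [
            {
                "name": "dog",
                "truncated": "0",
                "difficult": "0",
                "bndbox": {"xmin": "48", "ymin": "240", "xmax": "195", "ymax": "371"},
            }
        ]
    }
}

boxes, classes = get_voc_detection_gt(metadata, remove_hards=True)
print(boxes)    # [[ 47 239 195 371]]  (1-based xmin/ymin shifted to 0-based)
print(classes)  # ['dog']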
datasets/__init__.py
ADDED
File without changes
datasets/augmentations.py
ADDED
@@ -0,0 +1,68 @@
"""
Code borrowed from SelfMask: https://github.com/NoelShin/selfmask
"""

import numpy as np
import torch
from PIL import Image
from typing import Optional, Tuple, Union
from torchvision.transforms import ColorJitter, RandomApply, RandomGrayscale

from datasets.utils import GaussianBlur
from datasets.geometric_transforms import (
    random_scale,
    random_crop,
    random_hflip,
)

def geometric_augmentations(
    image: Image.Image,
    random_scale_range: Optional[Tuple[float, float]] = None,
    random_crop_size: Optional[int] = None,
    random_hflip_p: Optional[float] = None,
    mask: Optional[Union[Image.Image, np.ndarray, torch.Tensor]] = None,
    ignore_index: Optional[int] = None,
) -> Tuple[Image.Image, torch.Tensor]:
    """Note. image and mask are assumed to be of base size, thus share a spatial shape."""
    if random_scale_range is not None:
        image, mask = random_scale(
            image=image, random_scale_range=random_scale_range, mask=mask
        )

    if random_crop_size is not None:
        crop_size = (random_crop_size, random_crop_size)
        fill = tuple(np.array(image).mean(axis=(0, 1)).astype(np.uint8).tolist())
        image, offset = random_crop(image=image, crop_size=crop_size, fill=fill)

        if mask is not None:
            assert ignore_index is not None
            mask = random_crop(
                image=mask, crop_size=crop_size, fill=ignore_index, offset=offset
            )[0]

    if random_hflip_p is not None:
        image, mask = random_hflip(image=image, p=random_hflip_p, mask=mask)
    return image, mask

def photometric_augmentations(
    image: Image.Image,
    random_color_jitter: bool,
    random_grayscale: bool,
    random_gaussian_blur: bool,
    proba_photometric_aug: float,
) -> torch.Tensor:
    if random_color_jitter:
        color_jitter = ColorJitter(
            brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2
        )
        image = RandomApply([color_jitter], p=proba_photometric_aug)(image)

    if random_grayscale:
        image = RandomGrayscale(proba_photometric_aug)(image)

    if random_gaussian_blur:
        w, h = image.size
        image = GaussianBlur(kernel_size=int((0.1 * min(w, h) // 2 * 2) + 1))(
            image, proba_photometric_aug
        )
    return image
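A sketch of the geometric pipeline on a toy image (assuming the repo's datasets package is importable; random_hflip comes from geometric_transforms below, and the random image here is purely illustrative, whereas training uses DUTS-TR images and pseudo-masks):

import numpy as np
import torch
from PIL import Image
from datasets.augmentations import geometric_augmentations

# Toy 256x256 image and an all-zero mask tensor.
image = Image.fromarray(np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8))
mask = torch.zeros((256, 256))

# Random 224x224 crop shared by image and mask, plus a 50% horizontal flip.
image, mask = geometric_augmentations(
    image, random_crop_size=224, random_hflip_p=0.5, mask=mask, ignore_index=-1
)
print(image.size, mask.shape)  # (224, 224) torch.Size([224, 224])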
datasets/datasets.py
ADDED
@@ -0,0 +1,409 @@
"""
Dataset functions for applying Normalized Cut.
Code adapted from SelfMask: https://github.com/NoelShin/selfmask
"""

import os
from typing import Optional, Tuple, Union

from pycocotools.coco import COCO
import numpy as np
import torch
import torchvision
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms as T

from datasets.utils import unnormalize
from datasets.geometric_transforms import resize
from datasets.VOC import get_voc_detection_gt, create_gt_masks_if_voc, create_VOC_loader
from datasets.augmentations import geometric_augmentations, photometric_augmentations

from datasets.uod_datasets import UODDataset

NORMALIZE = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

def set_dataset_dir(dataset_name, root_dir):
    if dataset_name == "ECSSD":
        dataset_dir = os.path.join(root_dir, "ECSSD")
        img_dir = os.path.join(dataset_dir, "images")
        gt_dir = os.path.join(dataset_dir, "ground_truth_mask")

    elif dataset_name == "DUTS-TEST":
        dataset_dir = os.path.join(root_dir, "DUTS")
        img_dir = os.path.join(dataset_dir, "DUTS-TE-Image")
        gt_dir = os.path.join(dataset_dir, "DUTS-TE-Mask")

    elif dataset_name == "DUTS-TR":
        dataset_dir = os.path.join(root_dir, "DUTS")
        img_dir = os.path.join(dataset_dir, "DUTS-TR-Image")
        gt_dir = os.path.join(dataset_dir, "DUTS-TR-Mask")

    elif dataset_name == "DUT-OMRON":
        dataset_dir = os.path.join(root_dir, "DUT-OMRON")
        img_dir = os.path.join(dataset_dir, "DUT-OMRON-image")
        gt_dir = os.path.join(dataset_dir, "pixelwiseGT-new-PNG")

    elif dataset_name == "VOC07":
        dataset_dir = os.path.join(root_dir, "VOC2007")
        img_dir = dataset_dir
        gt_dir = dataset_dir

    elif dataset_name == "VOC12":
        dataset_dir = os.path.join('/datasets_local/osimeoni', "VOC2012")
        img_dir = dataset_dir
        gt_dir = dataset_dir

    elif dataset_name == "COCO17":
        dataset_dir = os.path.join(root_dir, "COCO")
        img_dir = dataset_dir
        gt_dir = dataset_dir

    elif dataset_name == "ImageNet":
        dataset_dir = os.path.join(root_dir, "ImageNet")
        img_dir = dataset_dir
        gt_dir = dataset_dir

    else:
        raise ValueError(f"Unknown dataset {dataset_name}")

    return img_dir, gt_dir


def build_dataset(
    root_dir: str,
    dataset_name: str,
    dataset_set: Optional[str] = None,
    for_eval: bool = False,
    config=None,
    evaluation_type="saliency",  # uod,
):
    """
    Build dataset
    """

    if evaluation_type == "saliency":
        img_dir, gt_dir = set_dataset_dir(dataset_name, root_dir)

        dataset = FoundDataset(
            name=dataset_name,
            img_dir=img_dir,
            gt_dir=gt_dir,
            dataset_set=dataset_set,
            config=config,
            for_eval=for_eval,
            evaluation_type=evaluation_type,
        )

    elif evaluation_type == "uod":
        assert dataset_name in ["VOC07", "VOC12", "COCO20k"]
        dataset_set = "trainval" if dataset_name in ["VOC07", "VOC12"] else "train"
        no_hards = False
        dataset = UODDataset(
            dataset_name,
            dataset_set,
            root_dir=root_dir,
            remove_hards=no_hards,
        )

    return dataset


class FoundDataset(Dataset):
    def __init__(
        self,
        name: str,
        img_dir: str,
        gt_dir: str,
        dataset_set: Optional[str] = None,
        config=None,
        for_eval: bool = False,
        evaluation_type: str = "saliency",
    ) -> None:
        """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.for_eval = for_eval
        self.use_aug = not for_eval
        self.evaluation_type = evaluation_type

        assert evaluation_type in ["saliency"]

        self.name = name
        self.dataset_set = dataset_set
        self.img_dir = img_dir
        self.gt_dir = gt_dir

        # if VOC dataset
        self.loader = None
        self.cocoGt = None

        self.config = config

        if "VOC" in self.name:
            self.loader = create_VOC_loader(self.img_dir, dataset_set, evaluation_type)

        # if ImageNet dataset
        elif "ImageNet" in self.name:
            self.loader = torchvision.datasets.ImageNet(
                self.img_dir,
                split=dataset_set,
                transform=None,
                target_transform=None,
            )

        elif "COCO" in self.name:
            year = int("20" + self.name[-2:])
            annFile = f'/datasets_local/COCO/annotations/instances_{dataset_set}{str(year)}.json'
            self.cocoGt = COCO(annFile)
            self.img_ids = list(sorted(self.cocoGt.getImgIds()))
            self.img_dir = f'/datasets_local/COCO/images/{dataset_set}{str(year)}/'

        # Transformations
        if self.for_eval:
            full_img_transform, no_norm_full_img_transform = self.get_init_transformation(
                isVOC="VOC" in name
            )
            self.full_img_transform = full_img_transform
            self.no_norm_full_img_transform = no_norm_full_img_transform

        # Images
        self.list_images = None
        if not "VOC" in self.name and not "COCO" in self.name:
            self.list_images = [
                os.path.join(img_dir, i) for i in sorted(os.listdir(img_dir))
            ]

        self.ignore_index = -1
        self.mean = NORMALIZE.mean
        self.std = NORMALIZE.std
        self.to_tensor_and_normalize = T.Compose([T.ToTensor(), NORMALIZE])
        self.normalize = NORMALIZE

        if config is not None and self.use_aug:
            self._set_aug(config)


    def get_init_transformation(self, isVOC: bool = False):
        if isVOC:
            t = T.Compose([T.PILToTensor(), T.ConvertImageDtype(torch.float), NORMALIZE])
            t_nonorm = T.Compose([T.PILToTensor(), T.ConvertImageDtype(torch.float)])
            return t, t_nonorm

        else:
            t = T.Compose([T.ToTensor(), NORMALIZE])
            t_nonorm = T.Compose([T.ToTensor()])
            return t, t_nonorm

    def _set_aug(self, config):
        """
        Set augmentation based on config.
        """

        photometric_aug = config.training["photometric_aug"]

        self.cropping_strategy = config.training["cropping_strategy"]
        if self.cropping_strategy == "center_crop":
            self.use_aug = False  # default strategy, not considered to be a data aug
        self.scale_range = config.training["scale_range"]
        self.crop_size = config.training["crop_size"]
        self.center_crop_transforms = T.Compose(
            [
                T.CenterCrop((self.crop_size, self.crop_size)),
                T.ToTensor(),
            ]
        )
        self.center_crop_only_transforms = T.Compose(
            [T.CenterCrop((self.crop_size, self.crop_size)), T.PILToTensor()]
        )

        self.proba_photometric_aug = config.training["proba_photometric_aug"]

        self.random_color_jitter = False
        self.random_grayscale = False
        self.random_gaussian_blur = False
        if photometric_aug == "color_jitter":
            self.random_color_jitter = True
        elif photometric_aug == "grayscale":
            self.random_grayscale = True
        elif photometric_aug == "gaussian_blur":
            self.random_gaussian_blur = True

    def _preprocess_data_aug(
        self,
        image: Image.Image,
        mask: Image.Image,
        ignore_index: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Prepare data in a proper form for either training (data augmentation) or validation."""

        # resize to base size
        image = resize(
            image,
            size=self.crop_size,
            edge="shorter",
            interpolation="bilinear",
        )
        mask = resize(
            mask,
            size=self.crop_size,
            edge="shorter",
            interpolation="bilinear",
        )

        if not isinstance(mask, torch.Tensor):
            mask: torch.Tensor = torch.tensor(np.array(mask))

        random_scale_range = None
        random_crop_size = None
        random_hflip_p = None
        if self.cropping_strategy == "random_scale":
            random_scale_range = self.scale_range
        elif self.cropping_strategy == "random_crop":
            random_crop_size = self.crop_size
        elif self.cropping_strategy == "random_hflip":
            random_hflip_p = 0.5
        elif self.cropping_strategy == "random_crop_and_hflip":
            random_hflip_p = 0.5
            random_crop_size = self.crop_size

        if random_crop_size or random_hflip_p or random_scale_range:
            image, mask = geometric_augmentations(
                image=image,
                mask=mask,
                random_scale_range=random_scale_range,
                random_crop_size=random_crop_size,
                ignore_index=ignore_index,
                random_hflip_p=random_hflip_p,
            )

        if random_scale_range:
            # resize to (self.crop_size, self.crop_size)
            image = resize(
                image,
                size=self.crop_size,
                interpolation="bilinear",
            )
            mask = resize(
                mask,
                size=(self.crop_size, self.crop_size),
                interpolation="bilinear",
            )

        image = photometric_augmentations(
            image,
            random_color_jitter=self.random_color_jitter,
            random_grayscale=self.random_grayscale,
            random_gaussian_blur=self.random_gaussian_blur,
            proba_photometric_aug=self.proba_photometric_aug,
        )

        # to tensor + normalize image
        image = self.to_tensor_and_normalize(image)

        return image, mask

    def __len__(self) -> int:
        if "VOC" in self.name:
            return len(self.loader)
        elif "ImageNet" in self.name:
            return len(self.loader)
        elif "COCO" in self.name:
            return len(self.img_ids)
        return len(self.list_images)

    def _apply_center_crop(
        self, image: Image.Image, mask: Union[Image.Image, np.ndarray, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        img_t = self.center_crop_transforms(image)
        # need to normalize image
        img_t = self.normalize(img_t)
        mask_gt = self.center_crop_transforms(mask).squeeze()
        return img_t, mask_gt


    def __getitem__(self, idx, get_mask_gt=True):
        if "VOC" in self.name:
            img, gt_labels = self.loader[idx]
            if self.evaluation_type == "uod":
                gt_labels, _ = get_voc_detection_gt(
                    gt_labels, remove_hards=False
                )
            elif self.evaluation_type == "saliency":
                mask_gt = create_gt_masks_if_voc(gt_labels)
            img_path = self.loader.images[idx]

        elif "ImageNet" in self.name:
            img, _ = self.loader[idx]
            img_path = self.loader.imgs[idx][0]
            # empty mask since no gt mask, only class label
            zeros = np.zeros(np.array(img).shape[:2])
            mask_gt = Image.fromarray(zeros)

        elif "COCO" in self.name:
            img_id = self.img_ids[idx]

            path = self.cocoGt.loadImgs(img_id)[0]["file_name"]
            img = Image.open(os.path.join(self.img_dir, path)).convert("RGB")
            _ = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(id))
            img_path = self.img_ids[idx]  # What matters most is the id for eval

            # empty mask since no gt mask, only class label
            zeros = np.zeros(np.array(img).shape[:2])
            mask_gt = Image.fromarray(zeros)

        # For all others
        else:
            img_path = self.list_images[idx]
            with open(img_path, "rb") as f:
                img = Image.open(f)
                img = img.convert("RGB")
            im_name = img_path.split("/")[-1]
            mask_gt = Image.open(
                os.path.join(self.gt_dir, im_name.replace(".jpg", ".png"))
            ).convert("L")

        if self.for_eval:
            img_t = self.full_img_transform(img)
            img_init = self.no_norm_full_img_transform(img)

            if self.evaluation_type == "saliency":
                mask_gt = torch.tensor(np.array(mask_gt)).squeeze()
                mask_gt = np.array(mask_gt)
                mask_gt = mask_gt == 255
                mask_gt = torch.tensor(mask_gt)
        else:
            if self.use_aug:
                img_t, mask_gt = self._preprocess_data_aug(
                    image=img, mask=mask_gt, ignore_index=self.ignore_index
                )
                mask_gt = np.array(mask_gt)
                mask_gt = mask_gt == 255
                mask_gt = torch.tensor(mask_gt)
            else:
                # no data aug
                img_t, mask_gt = self._apply_center_crop(image=img, mask=mask_gt)
                gt_labels = self.center_crop_only_transforms(gt_labels).squeeze()
                mask_gt = np.asarray(mask_gt, np.int64)
                mask_gt = mask_gt == 1
                mask_gt = torch.tensor(mask_gt)

            img_init = unnormalize(img_t)

        if not get_mask_gt:
            mask_gt = None

        if self.evaluation_type == "uod":
            gt_labels = torch.tensor(gt_labels)
            mask_gt = gt_labels

        return img_t, img_init, mask_gt, img_path

    def fullimg_mode(self):
        self.val_full_image = True

    def training_mode(self):
        self.val_full_image = False
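A sketch of how build_dataset is typically combined with a DataLoader for saliency evaluation. The root_dir value and the on-disk layout (ECSSD/images, ECSSD/ground_truth_mask under ./data) are assumptions matching set_dataset_dir above; load_config comes from misc.py as in app.py.

from torch.utils.data import DataLoader

from datasets.datasets import build_dataset
from misc import load_config

config = load_config("configs/found_DUTS-TR.yaml")
dataset = build_dataset(
    root_dir="data",
    dataset_name="ECSSD",
    for_eval=True,
    config=config,
    evaluation_type="saliency",
)
# batch_size=1 because evaluation images keep their original, varying sizes.
loader = DataLoader(dataset, batch_size=1, shuffle=False)

img_t, img_init, mask_gt, img_path = next(iter(loader))
print(img_t.shape, mask_gt.shape, img_path)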
datasets/geometric_transforms.py
ADDED
@@ -0,0 +1,160 @@
"""
Code adapted from SelfMask: https://github.com/NoelShin/selfmask
"""

from random import randint, random, uniform
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from torchvision.transforms.functional import InterpolationMode as IM


def random_crop(
    image: Union[Image.Image, np.ndarray, torch.Tensor],
    crop_size: Tuple[int, int],  # (h, w)
    fill: Union[int, Tuple[int, int, int]],  # an unsigned integer or RGB,
    offset: Optional[Tuple[int, int]] = None,  # (top, left) coordinate of a crop
):
    assert type(crop_size) in (tuple, list) and len(crop_size) == 2

    if isinstance(image, np.ndarray):
        image = torch.tensor(image)
        h, w = image.shape[-2:]
    elif isinstance(image, Image.Image):
        w, h = image.size
    elif isinstance(image, torch.Tensor):
        h, w = image.shape[-2:]
    else:
        raise TypeError(type(image))

    pad_h, pad_w = max(crop_size[0] - h, 0), max(crop_size[1] - w, 0)

    image = TF.pad(image, [0, 0, pad_w, pad_h], fill=fill, padding_mode="constant")

    if isinstance(image, Image.Image):
        w, h = image.size
    else:
        h, w = image.shape[-2:]

    if offset is None:
        offset = (randint(0, h - crop_size[0]), randint(0, w - crop_size[1]))

    image = TF.crop(
        image, top=offset[0], left=offset[1], height=crop_size[0], width=crop_size[1]
    )
    return image, offset


def compute_size(
    input_size: Tuple[int, int], output_size: int, edge: str  # h, w
) -> Tuple[int, int]:
    assert edge in ["shorter", "longer"]
    h, w = input_size

    if edge == "longer":
        if w > h:
            h = int(float(h) / w * output_size)
            w = output_size
        else:
            w = int(float(w) / h * output_size)
            h = output_size
        assert w <= output_size and h <= output_size

    else:
        if w > h:
            w = int(float(w) / h * output_size)
            h = output_size
        else:
            h = int(float(h) / w * output_size)
            w = output_size
        assert w >= output_size and h >= output_size
    return h, w


def resize(
    image: Union[Image.Image, np.ndarray, torch.Tensor],
    size: Union[int, Tuple[int, int]],
    interpolation: str,
    edge: str = "both",
) -> Union[Image.Image, torch.Tensor]:
    """
    :param image: an image to be resized
    :param size: a resulting image size
    :param interpolation: sampling mode. ["nearest", "bilinear", "bicubic"]
    :param edge: Default: "both"
        No-op if a size is given as a tuple (h, w).
        If set to "both", resize both height and width to the specified size.
        If set to "shorter", resize the shorter edge to the specified size keeping the aspect ratio.
        If set to "longer", resize the longer edge to the specified size keeping the aspect ratio.
    :return: a resized image
    """
    assert interpolation in ["nearest", "bilinear", "bicubic"], ValueError(
        interpolation
    )
    assert edge in ["both", "shorter", "longer"], ValueError(edge)
    interpolation = {
        "nearest": IM.NEAREST,
        "bilinear": IM.BILINEAR,
        "bicubic": IM.BICUBIC,
    }[interpolation]

    if type(image) == torch.Tensor:
        image = image.clone().detach()
    elif type(image) == np.ndarray:
        image = torch.from_numpy(image)

    if type(size) is tuple:
        if type(image) == torch.Tensor and len(image.shape) == 2:
            image = TF.resize(
                image.unsqueeze(dim=0), size=size, interpolation=interpolation
            ).squeeze(dim=0)
        else:
            image = TF.resize(image, size=size, interpolation=interpolation)

    else:
        if edge == "both":
            image = TF.resize(image, size=[size, size], interpolation=interpolation)

        else:
            if isinstance(image, Image.Image):
                w, h = image.size
            else:
                h, w = image.shape[-2:]
            rh, rw = compute_size(input_size=(h, w), output_size=size, edge=edge)
            image = TF.resize(image, size=[rh, rw], interpolation=interpolation)
    return image


def random_scale(
    image: Union[Image.Image, np.ndarray, torch.Tensor],
    random_scale_range: Tuple[float, float],
    mask: Optional[Union[Image.Image, np.ndarray, torch.Tensor]] = None,
):
    scale = uniform(*random_scale_range)
    if isinstance(image, Image.Image):
        w, h = image.size
    else:
        h, w = image.shape[-2:]
    w_rs, h_rs = int(w * scale), int(h * scale)
    image: Image.Image = resize(image, size=(h_rs, w_rs), interpolation="bilinear")
if mask is not None:
|
144 |
+
mask = resize(mask, size=(h_rs, w_rs), interpolation="nearest")
|
145 |
+
return image, mask
|
146 |
+
|
147 |
+
|
148 |
+
def random_hflip(
|
149 |
+
image: Union[Image.Image, np.ndarray, torch.Tensor],
|
150 |
+
p: float,
|
151 |
+
mask: Optional[Union[np.ndarray, torch.Tensor]] = None,
|
152 |
+
):
|
153 |
+
assert 0.0 <= p <= 1.0, ValueError(random_hflip)
|
154 |
+
|
155 |
+
# random() returns a float in the range [0.0, 1.0); the flip is applied when it exceeds p.
|
156 |
+
if random() > p:
|
157 |
+
image = TF.hflip(image)
|
158 |
+
if mask is not None:
|
159 |
+
mask = TF.hflip(mask)
|
160 |
+
return image, mask
|
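A minimal usage sketch of the transforms above (the image file is the example shipped under data/examples/, all other values are arbitrary): resize with edge="longer" keeps the aspect ratio, and random_scale / random_hflip return the image together with its (optional) mask so both stay aligned.

    from PIL import Image
    from datasets.geometric_transforms import resize, random_scale, random_hflip

    img = Image.open("data/examples/VOC_000030.jpg").convert("RGB")
    img = resize(img, size=224, interpolation="bilinear", edge="longer")   # longer edge -> 224 px, aspect ratio kept
    img, _ = random_scale(img, random_scale_range=(0.5, 2.0))              # rescale by a random factor
    img, _ = random_hflip(img, p=0.5)                                      # flip image (and mask, if one is given)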
datasets/uod_datasets.py
ADDED
@@ -0,0 +1,384 @@
1 |
+
# Copyright 2021 - Valeo Comfort and Driving Assistance - Oriane Siméoni @ valeo.ai
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""
|
16 |
+
Code adapted from previous method LOST: https://github.com/valeoai/LOST
|
17 |
+
"""
|
18 |
+
|
19 |
+
import os
|
20 |
+
import math
|
21 |
+
import torch
|
22 |
+
import json
|
23 |
+
import torchvision
|
24 |
+
import numpy as np
|
25 |
+
import skimage.io
|
26 |
+
|
27 |
+
from PIL import Image
|
28 |
+
from tqdm import tqdm
|
29 |
+
from torchvision import transforms as pth_transforms
|
30 |
+
|
31 |
+
# Image transformation applied to all images
|
32 |
+
transform = pth_transforms.Compose(
|
33 |
+
[
|
34 |
+
pth_transforms.ToTensor(),
|
35 |
+
pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
36 |
+
]
|
37 |
+
)
|
38 |
+
|
39 |
+
class ImageDataset:
|
40 |
+
def __init__(
|
41 |
+
self,
|
42 |
+
image_path
|
43 |
+
):
|
44 |
+
|
45 |
+
self.image_path = image_path
|
46 |
+
self.name = image_path.split("/")[-1]
|
47 |
+
|
48 |
+
# Read the image
|
49 |
+
with open(image_path, "rb") as f:
|
50 |
+
img = Image.open(f)
|
51 |
+
img = img.convert("RGB")
|
52 |
+
|
53 |
+
# Build a dataloader
|
54 |
+
img = transform(img)
|
55 |
+
self.dataloader = [[img, image_path]]
|
56 |
+
|
57 |
+
def get_image_name(self, *args, **kwargs):
|
58 |
+
return self.image_path.split("/")[-1].split(".")[0]
|
59 |
+
|
60 |
+
def load_image(self, *args, **kwargs):
|
61 |
+
return skimage.io.imread(self.image_path)
|
62 |
+
|
63 |
+
class UODDataset:
|
64 |
+
def __init__(
|
65 |
+
self,
|
66 |
+
dataset_name,
|
67 |
+
dataset_set,
|
68 |
+
root_dir,
|
69 |
+
remove_hards:bool = False,
|
70 |
+
):
|
71 |
+
"""
|
72 |
+
Build the dataloader
|
73 |
+
"""
|
74 |
+
|
75 |
+
self.dataset_name = dataset_name
|
76 |
+
self.set = dataset_set
|
77 |
+
self.root_dir = root_dir
|
78 |
+
|
79 |
+
if dataset_name == "VOC07":
|
80 |
+
self.root_path = f"{root_dir}/VOC2007"
|
81 |
+
self.year = "2007"
|
82 |
+
elif dataset_name == "VOC12":
|
83 |
+
self.root_path = f"{root_dir}/VOC2012"
|
84 |
+
self.year = "2012"
|
85 |
+
elif dataset_name == "COCO20k":
|
86 |
+
self.year = "2014"
|
87 |
+
self.root_path = f"{root_dir}/COCO/images/{dataset_set}{self.year}"
|
88 |
+
self.sel20k = 'data/coco_20k_filenames.txt'
|
89 |
+
# JSON file constructed based on COCO train2014 gt
|
90 |
+
self.all_annfile = f"{root_dir}/COCO/annotations/instances_train2014.json"
|
91 |
+
self.annfile = f"{root_dir}/instances_train2014_sel20k.json"
|
92 |
+
if not os.path.exists(self.annfile):
|
93 |
+
select_coco_20k(self.sel20k, self.all_annfile)
|
94 |
+
else:
|
95 |
+
raise ValueError("Unknown dataset.")
|
96 |
+
|
97 |
+
if not os.path.exists(self.root_path):
|
98 |
+
raise ValueError("Please follow the README to setup the datasets.")
|
99 |
+
|
100 |
+
self.name = f"{self.dataset_name}_{self.set}"
|
101 |
+
|
102 |
+
# Build the dataloader
|
103 |
+
if "VOC" in dataset_name:
|
104 |
+
self.dataloader = torchvision.datasets.VOCDetection(
|
105 |
+
self.root_path,
|
106 |
+
year=self.year,
|
107 |
+
image_set=self.set,
|
108 |
+
transform=transform,
|
109 |
+
download=False,
|
110 |
+
)
|
111 |
+
elif "COCO20k" == dataset_name:
|
112 |
+
self.dataloader = torchvision.datasets.CocoDetection(
|
113 |
+
self.root_path, annFile=self.annfile, transform=transform
|
114 |
+
)
|
115 |
+
else:
|
116 |
+
raise ValueError("Unknown dataset.")
|
117 |
+
|
118 |
+
# Set of hard images that will be discarded
|
119 |
+
self.remove_hards = remove_hards
|
120 |
+
self.hards = []
|
121 |
+
if remove_hards:
|
122 |
+
self.name += f"-nohards"
|
123 |
+
self.hards = self.get_hards()
|
124 |
+
print(f"Nb images discarded {len(self.hards)}")
|
125 |
+
|
126 |
+
def __len__(self) -> int:
|
127 |
+
return len(self.dataloader)
|
128 |
+
|
129 |
+
def load_image(self, im_name):
|
130 |
+
"""
|
131 |
+
Load the image corresponding to the im_name
|
132 |
+
"""
|
133 |
+
if "VOC" in self.dataset_name:
|
134 |
+
image = skimage.io.imread(f"{self.root_dir}/VOC{self.year}/JPEGImages/{im_name}")
|
135 |
+
elif "COCO" in self.dataset_name:
|
136 |
+
im_path = self.path_20k[self.sel_20k.index(im_name)]
|
137 |
+
image = skimage.io.imread(f"{self.root_dir}/COCO/images/{im_path}")
|
138 |
+
else:
|
139 |
+
raise ValueError("Unkown dataset.")
|
140 |
+
return image
|
141 |
+
|
142 |
+
def get_image_name(self, inp):
|
143 |
+
"""
|
144 |
+
Return the image name
|
145 |
+
"""
|
146 |
+
if "VOC" in self.dataset_name:
|
147 |
+
im_name = inp["annotation"]["filename"]
|
148 |
+
elif "COCO" in self.dataset_name:
|
149 |
+
im_name = str(inp[0]["image_id"])
|
150 |
+
|
151 |
+
return im_name
|
152 |
+
|
153 |
+
def extract_gt(self, targets, im_name):
|
154 |
+
if "VOC" in self.dataset_name:
|
155 |
+
return extract_gt_VOC(targets, remove_hards=self.remove_hards)
|
156 |
+
elif "COCO" in self.dataset_name:
|
157 |
+
return extract_gt_COCO(targets, remove_iscrowd=True)
|
158 |
+
else:
|
159 |
+
raise ValueError("Unknown dataset")
|
160 |
+
|
161 |
+
def extract_classes(self):
|
162 |
+
if "VOC" in self.dataset_name:
|
163 |
+
cls_path = f"classes_{self.set}_{self.year}.txt"
|
164 |
+
elif "COCO" in self.dataset_name:
|
165 |
+
cls_path = f"classes_{self.dataset}_{self.set}_{self.year}.txt"
|
166 |
+
|
167 |
+
# Load if exists
|
168 |
+
if os.path.exists(cls_path):
|
169 |
+
all_classes = []
|
170 |
+
with open(cls_path, "r") as f:
|
171 |
+
for line in f:
|
172 |
+
all_classes.append(line.strip())
|
173 |
+
else:
|
174 |
+
print("Extract all classes from the dataset")
|
175 |
+
if "VOC" in self.dataset_name:
|
176 |
+
all_classes = self.extract_classes_VOC()
|
177 |
+
elif "COCO" in self.dataset_name:
|
178 |
+
all_classes = self.extract_classes_COCO()
|
179 |
+
|
180 |
+
with open(cls_path, "w") as f:
|
181 |
+
for s in all_classes:
|
182 |
+
f.write(str(s) + "\n")
|
183 |
+
|
184 |
+
return all_classes
|
185 |
+
|
186 |
+
def extract_classes_VOC(self):
|
187 |
+
all_classes = []
|
188 |
+
for im_id, inp in enumerate(tqdm(self.dataloader)):
|
189 |
+
objects = inp[1]["annotation"]["object"]
|
190 |
+
|
191 |
+
for o in range(len(objects)):
|
192 |
+
if objects[o]["name"] not in all_classes:
|
193 |
+
all_classes.append(objects[o]["name"])
|
194 |
+
|
195 |
+
return all_classes
|
196 |
+
|
197 |
+
def extract_classes_COCO(self):
|
198 |
+
all_classes = []
|
199 |
+
for im_id, inp in enumerate(tqdm(self.dataloader)):
|
200 |
+
objects = inp[1]
|
201 |
+
|
202 |
+
for o in range(len(objects)):
|
203 |
+
if objects[o]["category_id"] not in all_classes:
|
204 |
+
all_classes.append(objects[o]["category_id"])
|
205 |
+
|
206 |
+
return all_classes
|
207 |
+
|
208 |
+
def get_hards(self):
|
209 |
+
hard_path = "datasets/hard_%s_%s_%s.txt" % (self.dataset_name, self.set, self.year)
|
210 |
+
if os.path.exists(hard_path):
|
211 |
+
hards = []
|
212 |
+
with open(hard_path, "r") as f:
|
213 |
+
for line in f:
|
214 |
+
hards.append(int(line.strip()))
|
215 |
+
else:
|
216 |
+
print("Discover hard images that should be discarded")
|
217 |
+
|
218 |
+
if "VOC" in self.dataset_name:
|
219 |
+
# set the hards
|
220 |
+
hards = discard_hard_voc(self.dataloader)
|
221 |
+
|
222 |
+
with open(hard_path, "w") as f:
|
223 |
+
for s in hards:
|
224 |
+
f.write(str(s) + "\n")
|
225 |
+
|
226 |
+
return hards
|
227 |
+
|
228 |
+
|
229 |
+
def discard_hard_voc(dataloader):
|
230 |
+
hards = []
|
231 |
+
for im_id, inp in enumerate(tqdm(dataloader)):
|
232 |
+
objects = inp[1]["annotation"]["object"]
|
233 |
+
nb_obj = len(objects)
|
234 |
+
|
235 |
+
hard = np.zeros(nb_obj)
|
236 |
+
for i, o in enumerate(range(nb_obj)):
|
237 |
+
hard[i] = (
|
238 |
+
1
|
239 |
+
if (objects[o]["truncated"] == "1" or objects[o]["difficult"] == "1")
|
240 |
+
else 0
|
241 |
+
)
|
242 |
+
|
243 |
+
# all images with only truncated or difficult objects
|
244 |
+
if np.sum(hard) == nb_obj:
|
245 |
+
hards.append(im_id)
|
246 |
+
return hards
|
247 |
+
|
248 |
+
|
249 |
+
def extract_gt_COCO(targets, remove_iscrowd=True):
|
250 |
+
objects = targets
|
251 |
+
nb_obj = len(objects)
|
252 |
+
|
253 |
+
gt_bbxs = []
|
254 |
+
gt_clss = []
|
255 |
+
for o in range(nb_obj):
|
256 |
+
# Remove iscrowd boxes
|
257 |
+
if remove_iscrowd and objects[o]["iscrowd"] == 1:
|
258 |
+
continue
|
259 |
+
gt_cls = objects[o]["category_id"]
|
260 |
+
gt_clss.append(gt_cls)
|
261 |
+
bbx = objects[o]["bbox"]
|
262 |
+
x1y1x2y2 = [bbx[0], bbx[1], bbx[0] + bbx[2], bbx[1] + bbx[3]]
|
263 |
+
x1y1x2y2 = [int(round(x)) for x in x1y1x2y2]
|
264 |
+
gt_bbxs.append(x1y1x2y2)
|
265 |
+
|
266 |
+
return np.asarray(gt_bbxs), gt_clss
|
267 |
+
|
268 |
+
|
269 |
+
def extract_gt_VOC(targets, remove_hards=False):
|
270 |
+
objects = targets["annotation"]["object"]
|
271 |
+
nb_obj = len(objects)
|
272 |
+
|
273 |
+
gt_bbxs = []
|
274 |
+
gt_clss = []
|
275 |
+
for o in range(nb_obj):
|
276 |
+
if remove_hards and (
|
277 |
+
objects[o]["truncated"] == "1" or objects[o]["difficult"] == "1"
|
278 |
+
):
|
279 |
+
continue
|
280 |
+
gt_cls = objects[o]["name"]
|
281 |
+
gt_clss.append(gt_cls)
|
282 |
+
obj = objects[o]["bndbox"]
|
283 |
+
x1y1x2y2 = [
|
284 |
+
int(obj["xmin"]),
|
285 |
+
int(obj["ymin"]),
|
286 |
+
int(obj["xmax"]),
|
287 |
+
int(obj["ymax"]),
|
288 |
+
]
|
289 |
+
# Original annotations are integers in the range [1, W or H]
|
290 |
+
# Assuming they mean 1-based pixel indices (inclusive),
|
291 |
+
# a box with annotation (xmin=1, xmax=W) covers the whole image.
|
292 |
+
# In coordinate space this is represented by (xmin=0, xmax=W)
|
293 |
+
x1y1x2y2[0] -= 1
|
294 |
+
x1y1x2y2[1] -= 1
|
295 |
+
gt_bbxs.append(x1y1x2y2)
|
296 |
+
|
297 |
+
return np.asarray(gt_bbxs), gt_clss
|
298 |
+
|
299 |
+
|
300 |
+
def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
|
301 |
+
# https://github.com/ultralytics/yolov5/blob/develop/utils/general.py
|
302 |
+
# Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
|
303 |
+
box2 = box2.T
|
304 |
+
|
305 |
+
# Get the coordinates of bounding boxes
|
306 |
+
if x1y1x2y2: # x1, y1, x2, y2 = box1
|
307 |
+
b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
|
308 |
+
b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
|
309 |
+
else: # transform from xywh to xyxy
|
310 |
+
b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
|
311 |
+
b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
|
312 |
+
b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
|
313 |
+
b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
|
314 |
+
|
315 |
+
# Intersection area
|
316 |
+
inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * (
|
317 |
+
torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)
|
318 |
+
).clamp(0)
|
319 |
+
|
320 |
+
# Union Area
|
321 |
+
w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
|
322 |
+
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
|
323 |
+
union = w1 * h1 + w2 * h2 - inter + eps
|
324 |
+
|
325 |
+
iou = inter / union
|
326 |
+
if GIoU or DIoU or CIoU:
|
327 |
+
cw = torch.max(b1_x2, b2_x2) - torch.min(
|
328 |
+
b1_x1, b2_x1
|
329 |
+
) # convex (smallest enclosing box) width
|
330 |
+
ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
|
331 |
+
if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
|
332 |
+
c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
|
333 |
+
rho2 = (
|
334 |
+
(b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2
|
335 |
+
+ (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2
|
336 |
+
) / 4 # center distance squared
|
337 |
+
if DIoU:
|
338 |
+
return iou - rho2 / c2 # DIoU
|
339 |
+
elif (
|
340 |
+
CIoU
|
341 |
+
): # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
|
342 |
+
v = (4 / math.pi ** 2) * torch.pow(
|
343 |
+
torch.atan(w2 / h2) - torch.atan(w1 / h1), 2
|
344 |
+
)
|
345 |
+
with torch.no_grad():
|
346 |
+
alpha = v / (v - iou + (1 + eps))
|
347 |
+
return iou - (rho2 / c2 + v * alpha) # CIoU
|
348 |
+
else: # GIoU https://arxiv.org/pdf/1902.09630.pdf
|
349 |
+
c_area = cw * ch + eps # convex area
|
350 |
+
return iou - (c_area - union) / c_area # GIoU
|
351 |
+
else:
|
352 |
+
return iou # IoU
|
353 |
+
|
354 |
+
def select_coco_20k(sel_file, all_annotations_file):
|
355 |
+
print('Building COCO 20k dataset.')
|
356 |
+
|
357 |
+
# load all annotations
|
358 |
+
with open(all_annotations_file, "r") as f:
|
359 |
+
train2014 = json.load(f)
|
360 |
+
|
361 |
+
# load selected images
|
362 |
+
with open(sel_file, "r") as f:
|
363 |
+
sel_20k = f.readlines()
|
364 |
+
sel_20k = [s.replace("\n", "") for s in sel_20k]
|
365 |
+
im20k = [str(int(s.split("_")[-1].split(".")[0])) for s in sel_20k]
|
366 |
+
|
367 |
+
new_anno = []
|
368 |
+
new_images = []
|
369 |
+
|
370 |
+
for i in tqdm(im20k):
|
371 |
+
new_anno.extend(
|
372 |
+
[a for a in train2014["annotations"] if a["image_id"] == int(i)]
|
373 |
+
)
|
374 |
+
new_images.extend([a for a in train2014["images"] if a["id"] == int(i)])
|
375 |
+
|
376 |
+
train2014_20k = {}
|
377 |
+
train2014_20k["images"] = new_images
|
378 |
+
train2014_20k["annotations"] = new_anno
|
379 |
+
train2014_20k["categories"] = train2014["categories"]
|
380 |
+
|
381 |
+
with open("datasets/instances_train2014_sel20k.json", "w") as outfile:
|
382 |
+
json.dump(train2014_20k, outfile)
|
383 |
+
|
384 |
+
print('Done.')
|
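As a quick sanity check of bbox_iou above (boxes are arbitrary, given as x1y1x2y2 pixel coordinates, with box2 holding several candidate boxes):

    import torch
    from datasets.uod_datasets import bbox_iou

    pred = torch.tensor([0, 0, 10, 10])
    gts = torch.tensor([[0, 0, 10, 10], [5, 5, 15, 15]])
    ious = bbox_iou(pred, gts)   # ~[1.00, 0.14]: exact match, then a 5x5 overlap over a 175-px union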
datasets/utils.py
ADDED
@@ -0,0 +1,44 @@
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from PIL import Image
|
4 |
+
from torchvision import transforms as T
|
5 |
+
|
6 |
+
NORMALIZE = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
|
7 |
+
|
8 |
+
class GaussianBlur:
|
9 |
+
"""
|
10 |
+
Code borrowed from SelfMask: https://github.com/NoelShin/selfmask
|
11 |
+
"""
|
12 |
+
|
13 |
+
# Implements Gaussian blur as described in the SimCLR paper
|
14 |
+
def __init__(self, kernel_size: float, min: float = 0.1, max: float = 2.0) -> None:
|
15 |
+
self.min = min
|
16 |
+
self.max = max
|
17 |
+
# kernel size is set to be 10% of the image height/width
|
18 |
+
self.kernel_size = kernel_size
|
19 |
+
|
20 |
+
def __call__(self, sample: Image.Image, random_gaussian_blur_p: float):
|
21 |
+
sample = np.array(sample)
|
22 |
+
|
23 |
+
# blur the image with a 50% chance
|
24 |
+
prob = np.random.random_sample()
|
25 |
+
|
26 |
+
if prob < 0.5:
|
27 |
+
import cv2
|
28 |
+
|
29 |
+
sigma = (self.max - self.min) * np.random.random_sample() + self.min
|
30 |
+
sample = cv2.GaussianBlur(
|
31 |
+
sample, (self.kernel_size, self.kernel_size), sigma
|
32 |
+
)
|
33 |
+
return sample
|
34 |
+
|
35 |
+
|
36 |
+
def unnormalize(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
|
37 |
+
"""
|
38 |
+
Code borrowed from STEGO: https://github.com/mhamilton723/STEGO
|
39 |
+
"""
|
40 |
+
image2 = torch.clone(image)
|
41 |
+
for t, m, s in zip(image2, mean, std):
|
42 |
+
t.mul_(s).add_(m)
|
43 |
+
|
44 |
+
return image2
|
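A small sketch of how unnormalize undoes the ImageNet normalization applied by NORMALIZE, e.g. before visualizing a tensor (the random input is just a stand-in for an image):

    import torch
    from datasets.utils import NORMALIZE, unnormalize

    img = torch.rand(3, 224, 224)        # fake RGB image in [0, 1]
    img_norm = NORMALIZE(img)            # (x - mean) / std per channel
    img_back = unnormalize(img_norm)     # x * std + mean, channel by channel
    assert torch.allclose(img, img_back, atol=1e-5)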
evaluation/__init__.py
ADDED
File without changes
|
evaluation/metrics/__init__.py
ADDED
File without changes
|
evaluation/metrics/average_meter.py
ADDED
@@ -0,0 +1,21 @@
1 |
+
"""
|
2 |
+
Code borrowed from SelfMask: https://github.com/NoelShin/selfmask
|
3 |
+
"""
|
4 |
+
|
5 |
+
class AverageMeter(object):
|
6 |
+
"""Computes and stores the average and current value"""
|
7 |
+
|
8 |
+
def __init__(self):
|
9 |
+
self.reset()
|
10 |
+
|
11 |
+
def reset(self):
|
12 |
+
self.val = 0
|
13 |
+
self.avg = 0
|
14 |
+
self.sum = 0
|
15 |
+
self.count = 0
|
16 |
+
|
17 |
+
def update(self, val, n: int):
|
18 |
+
self.val = val
|
19 |
+
self.sum += val * n
|
20 |
+
self.count += n
|
21 |
+
self.avg = self.sum / self.count
|
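For reference, a tiny sketch of how the meter is used by the evaluation code below (values are arbitrary); each update is weighted by its sample count n:

    from evaluation.metrics.average_meter import AverageMeter

    meter = AverageMeter()
    meter.update(val=0.8, n=1)
    meter.update(val=0.6, n=3)
    print(meter.avg)   # (0.8 * 1 + 0.6 * 3) / 4 = 0.65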
evaluation/metrics/f_measure.py
ADDED
@@ -0,0 +1,111 @@
1 |
+
"""
|
2 |
+
Code borrowed from SelfMask: https://github.com/NoelShin/selfmask
|
3 |
+
"""
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
class FMeasure:
|
8 |
+
def __init__(
|
9 |
+
self,
|
10 |
+
default_thres: float = 0.5,
|
11 |
+
beta_square: float = 0.3,
|
12 |
+
n_bins: int = 255,
|
13 |
+
eps: float = 1e-7,
|
14 |
+
):
|
15 |
+
"""
|
16 |
+
:param default_thres: a hyperparameter for F-measure that is used to binarize a predicted mask. Default: 0.5
|
17 |
+
:param beta_square: a hyperparameter for F-measure. Default: 0.3
|
18 |
+
:param n_bins: the number of thresholds that will be tested for F-max. Default: 255
|
19 |
+
:param eps: a small value for numerical stability
|
20 |
+
"""
|
21 |
+
|
22 |
+
self.beta_square = beta_square
|
23 |
+
self.default_thres = default_thres
|
24 |
+
self.eps = eps
|
25 |
+
self.n_bins = n_bins
|
26 |
+
|
27 |
+
def _compute_precision_recall(
|
28 |
+
self, binary_pred_mask: torch.Tensor, gt_mask: torch.Tensor
|
29 |
+
) -> torch.Tensor:
|
30 |
+
"""
|
31 |
+
:param binary_pred_mask: (B x H x W) or (H x W)
|
32 |
+
:param gt_mask: (B x H x W) or (H x W), should be the same with binary_pred_mask
|
33 |
+
"""
|
34 |
+
tp = torch.logical_and(binary_pred_mask, gt_mask).sum(dim=(-1, -2))
|
35 |
+
tp_fp = binary_pred_mask.sum(dim=(-1, -2))
|
36 |
+
tp_fn = gt_mask.sum(dim=(-1, -2))
|
37 |
+
|
38 |
+
prec = tp / (tp_fp + self.eps)
|
39 |
+
recall = tp / (tp_fn + self.eps)
|
40 |
+
return prec, recall
|
41 |
+
|
42 |
+
def _compute_f_measure(
|
43 |
+
self,
|
44 |
+
pred_mask: torch.Tensor,
|
45 |
+
gt_mask: torch.Tensor,
|
46 |
+
thresholds: torch.Tensor = None,
|
47 |
+
) -> torch.Tensor:
|
48 |
+
if thresholds is None:
|
49 |
+
binary_pred_mask = pred_mask > self.default_thres
|
50 |
+
else:
|
51 |
+
binary_pred_mask = pred_mask > thresholds
|
52 |
+
|
53 |
+
prec, recall = self._compute_precision_recall(binary_pred_mask, gt_mask)
|
54 |
+
f_measure = ((1 + (self.beta_square**2)) * prec * recall) / (
|
55 |
+
(self.beta_square**2) * prec + recall + self.eps
|
56 |
+
)
|
57 |
+
return f_measure.cpu()
|
58 |
+
|
59 |
+
def _compute_f_max(
|
60 |
+
self, pred_mask: torch.Tensor, gt_mask: torch.Tensor
|
61 |
+
) -> torch.Tensor:
|
62 |
+
"""Compute self.n_bins + 1 F-measures, each of which has a different threshold, then return the maximum
|
63 |
+
F-measure among them.
|
64 |
+
|
65 |
+
:param pred_mask: (H x W)
|
66 |
+
:param gt_mask: (H x W)
|
67 |
+
"""
|
68 |
+
|
69 |
+
# pred_masks, gt_masks: H x W -> self.n_bins x H x W
|
70 |
+
pred_masks = pred_mask.unsqueeze(dim=0).repeat(self.n_bins, 1, 1)
|
71 |
+
gt_masks = gt_mask.unsqueeze(dim=0).repeat(self.n_bins, 1, 1)
|
72 |
+
|
73 |
+
# thresholds: self.n_bins x 1 x 1
|
74 |
+
thresholds = (
|
75 |
+
torch.arange(0, 1, 1 / self.n_bins)
|
76 |
+
.view(self.n_bins, 1, 1)
|
77 |
+
.to(pred_masks.device)
|
78 |
+
)
|
79 |
+
|
80 |
+
# f_measures: self.n_bins
|
81 |
+
f_measures = self._compute_f_measure(pred_masks, gt_masks, thresholds)
|
82 |
+
return torch.max(f_measures).cpu(), f_measures
|
83 |
+
|
84 |
+
def _compute_f_mean(
|
85 |
+
self,
|
86 |
+
pred_mask: torch.Tensor,
|
87 |
+
gt_mask: torch.Tensor,
|
88 |
+
) -> torch.Tensor:
|
89 |
+
adaptive_thres = 2 * pred_mask.mean(dim=(-1, -2), keepdim=True)
|
90 |
+
binary_pred_mask = pred_mask > adaptive_thres
|
91 |
+
|
92 |
+
prec, recall = self._compute_precision_recall(binary_pred_mask, gt_mask)
|
93 |
+
f_mean = ((1 + (self.beta_square**2)) * prec * recall) / (
|
94 |
+
(self.beta_square**2) * prec + recall + self.eps
|
95 |
+
)
|
96 |
+
return f_mean.cpu()
|
97 |
+
|
98 |
+
def __call__(self, pred_mask: torch.Tensor, gt_mask: torch.Tensor) -> dict:
|
99 |
+
"""
|
100 |
+
:param pred_mask: (H x W) a normalized prediction mask with values in [0, 1]
|
101 |
+
:param gt_mask: (H x W) a binary ground truth mask with values in {0, 1}
|
102 |
+
:return: a dictionary with keys being "f_measure" and "f_max" and values being the respective values.
|
103 |
+
"""
|
104 |
+
outputs: dict = dict()
|
105 |
+
for k in ("f_measure", "f_mean"):
|
106 |
+
outputs.update({k: getattr(self, f"_compute_{k}")(pred_mask, gt_mask)})
|
107 |
+
|
108 |
+
f_max_, all_f = self._compute_f_max(pred_mask, gt_mask)
|
109 |
+
outputs["f_max"] = f_max_
|
110 |
+
outputs["all_f"] = all_f # List of all f values for all thresholds
|
111 |
+
return outputs
|
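A minimal sketch of calling the metric on a single soft prediction (random tensors stand in for a real mask pair):

    import torch
    from evaluation.metrics.f_measure import FMeasure

    pred = torch.rand(64, 64)                 # soft saliency map in [0, 1]
    gt = (torch.rand(64, 64) > 0.5).float()   # binary ground truth
    scores = FMeasure()(pred, gt)
    print(scores["f_measure"], scores["f_max"], scores["f_mean"])  # scores["all_f"] holds one value per threshold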
evaluation/metrics/iou.py
ADDED
@@ -0,0 +1,37 @@
1 |
+
"""
|
2 |
+
Code adapted from SelfMask: https://github.com/NoelShin/selfmask
|
3 |
+
"""
|
4 |
+
|
5 |
+
from typing import Optional, Union
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
|
10 |
+
|
11 |
+
def compute_iou(
|
12 |
+
pred_mask: Union[np.ndarray, torch.Tensor],
|
13 |
+
gt_mask: Union[np.ndarray, torch.Tensor],
|
14 |
+
threshold: Optional[float] = 0.5,
|
15 |
+
eps: float = 1e-7,
|
16 |
+
) -> Union[np.ndarray, torch.Tensor]:
|
17 |
+
"""
|
18 |
+
:param pred_mask: (B x H x W) or (H x W)
|
19 |
+
:param gt_mask: (B x H x W) or (H x W), same shape with pred_mask
|
20 |
+
:param threshold: a binarization threshold
|
21 |
+
:param eps: a small value for computational stability
|
22 |
+
:return: (B) or (1)
|
23 |
+
"""
|
24 |
+
assert pred_mask.shape == gt_mask.shape, f"{pred_mask.shape} != {gt_mask.shape}"
|
25 |
+
# assert 0. <= pred_mask.to(torch.float32).min() and pred_mask.max().to(torch.float32) <= 1., f"{pred_mask.min(), pred_mask.max()}"
|
26 |
+
|
27 |
+
if threshold is not None:
|
28 |
+
pred_mask = pred_mask > threshold
|
29 |
+
if isinstance(pred_mask, np.ndarray):
|
30 |
+
intersection = np.logical_and(pred_mask, gt_mask).sum(axis=(-1, -2))
|
31 |
+
union = np.logical_or(pred_mask, gt_mask).sum(axis=(-1, -2))
|
32 |
+
ious = intersection / (union + eps)
|
33 |
+
else:
|
34 |
+
intersection = torch.logical_and(pred_mask, gt_mask).sum(dim=(-1, -2))
|
35 |
+
union = torch.logical_or(pred_mask, gt_mask).sum(dim=(-1, -2))
|
36 |
+
ious = (intersection / (union + eps)).cpu()
|
37 |
+
return ious
|
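A short sketch of compute_iou on a small batch (random tensors as placeholders); the soft masks are binarized at the given threshold before taking intersection over union:

    import torch
    from evaluation.metrics.iou import compute_iou

    preds = torch.rand(2, 64, 64)                  # two soft masks
    gts = torch.rand(2, 64, 64) > 0.5              # two binary masks
    print(compute_iou(preds, gts, threshold=0.5))  # one IoU per mask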
evaluation/metrics/mae.py
ADDED
@@ -0,0 +1,14 @@
1 |
+
"""
|
2 |
+
Code borrowed from SelfMask: https://github.com/NoelShin/selfmask
|
3 |
+
"""
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
def compute_mae(pred_mask: torch.Tensor, gt_mask: torch.Tensor) -> torch.Tensor:
|
8 |
+
"""
|
9 |
+
:param pred_mask: (H x W) or (B x H x W) a normalized prediction mask with values in [0, 1]
|
10 |
+
:param gt_mask: (H x W) or (B x H x W) a binary ground truth mask with values in {0, 1}
|
11 |
+
"""
|
12 |
+
return torch.mean(
|
13 |
+
torch.abs(pred_mask - gt_mask.to(torch.float32)), dim=(-1, -2)
|
14 |
+
).cpu()
|
evaluation/metrics/pixel_acc.py
ADDED
@@ -0,0 +1,21 @@
1 |
+
"""
|
2 |
+
Code borrowed from SelfMask: https://github.com/NoelShin/selfmask
|
3 |
+
"""
|
4 |
+
|
5 |
+
from typing import Optional
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
|
10 |
+
def compute_pixel_accuracy(
|
11 |
+
pred_mask: torch.Tensor, gt_mask: torch.Tensor, threshold: Optional[float] = 0.5
|
12 |
+
) -> torch.Tensor:
|
13 |
+
"""
|
14 |
+
:param pred_mask: (H x W) or (B x H x W) a normalized prediction mask with values in [0, 1]
|
15 |
+
:param gt_mask: (H x W) or (B x H x W) a binary ground truth mask with values in {0, 1}
|
16 |
+
"""
|
17 |
+
if threshold is not None:
|
18 |
+
binary_pred_mask = pred_mask > threshold
|
19 |
+
else:
|
20 |
+
binary_pred_mask = pred_mask
|
21 |
+
return (binary_pred_mask == gt_mask).to(torch.float32).mean(dim=(-1, -2)).cpu()
|
evaluation/metrics/s_measure.py
ADDED
@@ -0,0 +1,126 @@
1 |
+
# code borrowed from https://github.com/Hanqer/Evaluate-SOD/blob/master/evaluator.py
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
class SMeasure:
|
7 |
+
def __init__(self, alpha: float = 0.5):
|
8 |
+
self.alpha: float = alpha
|
9 |
+
self.cuda: bool = True
|
10 |
+
|
11 |
+
def _centroid(self, gt):
|
12 |
+
rows, cols = gt.size()[-2:]
|
13 |
+
gt = gt.view(rows, cols)
|
14 |
+
if gt.sum() == 0:
|
15 |
+
if self.cuda:
|
16 |
+
X = torch.eye(1).cuda() * round(cols / 2)
|
17 |
+
Y = torch.eye(1).cuda() * round(rows / 2)
|
18 |
+
else:
|
19 |
+
X = torch.eye(1) * round(cols / 2)
|
20 |
+
Y = torch.eye(1) * round(rows / 2)
|
21 |
+
else:
|
22 |
+
total = gt.sum()
|
23 |
+
if self.cuda:
|
24 |
+
i = torch.from_numpy(np.arange(0, cols)).cuda().float()
|
25 |
+
j = torch.from_numpy(np.arange(0, rows)).cuda().float()
|
26 |
+
else:
|
27 |
+
i = torch.from_numpy(np.arange(0, cols)).float()
|
28 |
+
j = torch.from_numpy(np.arange(0, rows)).float()
|
29 |
+
X = torch.round((gt.sum(dim=0) * i).sum() / total)
|
30 |
+
Y = torch.round((gt.sum(dim=1) * j).sum() / total)
|
31 |
+
return X.long(), Y.long()
|
32 |
+
|
33 |
+
def _ssim(self, pred, gt):
|
34 |
+
gt = gt.float()
|
35 |
+
h, w = pred.size()[-2:]
|
36 |
+
N = h * w
|
37 |
+
x = pred.mean()
|
38 |
+
y = gt.mean()
|
39 |
+
sigma_x2 = ((pred - x) * (pred - x)).sum() / (N - 1 + 1e-20)
|
40 |
+
sigma_y2 = ((gt - y) * (gt - y)).sum() / (N - 1 + 1e-20)
|
41 |
+
sigma_xy = ((pred - x) * (gt - y)).sum() / (N - 1 + 1e-20)
|
42 |
+
|
43 |
+
aplha = 4 * x * y * sigma_xy
|
44 |
+
beta = (x * x + y * y) * (sigma_x2 + sigma_y2)
|
45 |
+
|
46 |
+
if aplha != 0:
|
47 |
+
Q = aplha / (beta + 1e-20)
|
48 |
+
elif aplha == 0 and beta == 0:
|
49 |
+
Q = 1.0
|
50 |
+
else:
|
51 |
+
Q = 0
|
52 |
+
return Q
|
53 |
+
|
54 |
+
def _object(self, pred, gt):
|
55 |
+
temp = pred[gt == 1]
|
56 |
+
x = temp.mean()
|
57 |
+
sigma_x = temp.std()
|
58 |
+
score = 2.0 * x / (x * x + 1.0 + sigma_x + 1e-20)
|
59 |
+
|
60 |
+
return score
|
61 |
+
|
62 |
+
def _s_object(self, pred, gt):
|
63 |
+
fg = torch.where(gt == 0, torch.zeros_like(pred), pred)
|
64 |
+
bg = torch.where(gt == 1, torch.zeros_like(pred), 1 - pred)
|
65 |
+
o_fg = self._object(fg, gt)
|
66 |
+
o_bg = self._object(bg, 1 - gt)
|
67 |
+
u = gt.mean()
|
68 |
+
Q = u * o_fg + (1 - u) * o_bg
|
69 |
+
return Q
|
70 |
+
|
71 |
+
def _divide_gt(self, gt, X, Y):
|
72 |
+
h, w = gt.size()[-2:]
|
73 |
+
area = h * w
|
74 |
+
gt = gt.view(h, w)
|
75 |
+
LT = gt[:Y, :X]
|
76 |
+
RT = gt[:Y, X:w]
|
77 |
+
LB = gt[Y:h, :X]
|
78 |
+
RB = gt[Y:h, X:w]
|
79 |
+
X = X.float()
|
80 |
+
Y = Y.float()
|
81 |
+
w1 = X * Y / area
|
82 |
+
w2 = (w - X) * Y / area
|
83 |
+
w3 = X * (h - Y) / area
|
84 |
+
w4 = 1 - w1 - w2 - w3
|
85 |
+
return LT, RT, LB, RB, w1, w2, w3, w4
|
86 |
+
|
87 |
+
def _divide_prediction(self, pred, X, Y):
|
88 |
+
h, w = pred.size()[-2:]
|
89 |
+
pred = pred.view(h, w)
|
90 |
+
LT = pred[:Y, :X]
|
91 |
+
RT = pred[:Y, X:w]
|
92 |
+
LB = pred[Y:h, :X]
|
93 |
+
RB = pred[Y:h, X:w]
|
94 |
+
return LT, RT, LB, RB
|
95 |
+
|
96 |
+
def _s_region(self, pred, gt):
|
97 |
+
X, Y = self._centroid(gt)
|
98 |
+
gt1, gt2, gt3, gt4, w1, w2, w3, w4 = self._divide_gt(gt, X, Y)
|
99 |
+
p1, p2, p3, p4 = self._divide_prediction(pred, X, Y)
|
100 |
+
Q1 = self._ssim(p1, gt1)
|
101 |
+
Q2 = self._ssim(p2, gt2)
|
102 |
+
Q3 = self._ssim(p3, gt3)
|
103 |
+
Q4 = self._ssim(p4, gt4)
|
104 |
+
Q = w1 * Q1 + w2 * Q2 + w3 * Q3 + w4 * Q4
|
105 |
+
# print(Q)
|
106 |
+
return Q
|
107 |
+
|
108 |
+
def __call__(self, pred_mask: torch.Tensor, gt_mask: torch.Tensor):
|
109 |
+
assert pred_mask.shape == gt_mask.shape
|
110 |
+
y = gt_mask.mean()
|
111 |
+
if y == 0:
|
112 |
+
x = pred_mask.mean()
|
113 |
+
Q = 1.0 - x
|
114 |
+
elif y == 1:
|
115 |
+
x = pred_mask.mean()
|
116 |
+
Q = x
|
117 |
+
else:
|
118 |
+
gt_mask[gt_mask >= 0.5] = 1
|
119 |
+
gt_mask[gt_mask < 0.5] = 0
|
120 |
+
# print(self._S_object(pred, gt), self._S_region(pred, gt))
|
121 |
+
Q = self.alpha * self._s_object(pred_mask, gt_mask) + (
|
122 |
+
1 - self.alpha
|
123 |
+
) * self._s_region(pred_mask, gt_mask)
|
124 |
+
if Q.item() < 0:
|
125 |
+
Q = torch.FloatTensor([0.0])
|
126 |
+
return Q.item()
|
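A minimal sketch of the S-measure call; note that the implementation above hard-codes self.cuda = True, so the inputs are assumed to live on the GPU:

    import torch
    from evaluation.metrics.s_measure import SMeasure

    pred = torch.rand(64, 64).cuda()                  # soft prediction
    gt = (torch.rand(64, 64) > 0.5).float().cuda()    # binary ground truth
    print(SMeasure(alpha=0.5)(pred_mask=pred, gt_mask=gt))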
evaluation/saliency.py
ADDED
@@ -0,0 +1,290 @@
1 |
+
# Copyright 2022 - Valeo Comfort and Driving Assistance - valeo.ai
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import numpy as np
|
17 |
+
import torch.nn as nn
|
18 |
+
import torch.nn.functional as F
|
19 |
+
|
20 |
+
from tqdm import tqdm
|
21 |
+
from scipy import ndimage
|
22 |
+
|
23 |
+
from evaluation.metrics.average_meter import AverageMeter
|
24 |
+
from evaluation.metrics.f_measure import FMeasure
|
25 |
+
from evaluation.metrics.iou import compute_iou
|
26 |
+
from evaluation.metrics.mae import compute_mae
|
27 |
+
from evaluation.metrics.pixel_acc import compute_pixel_accuracy
|
28 |
+
from evaluation.metrics.s_measure import SMeasure
|
29 |
+
|
30 |
+
from misc import batch_apply_bilateral_solver
|
31 |
+
|
32 |
+
|
33 |
+
@torch.no_grad()
|
34 |
+
def write_metric_tf(
|
35 |
+
writer,
|
36 |
+
metrics,
|
37 |
+
n_iter = -1,
|
38 |
+
name = ""
|
39 |
+
):
|
40 |
+
writer.add_scalar(
|
41 |
+
f"Validation/{name}iou_pred",
|
42 |
+
metrics["ious"].avg,
|
43 |
+
n_iter,
|
44 |
+
)
|
45 |
+
writer.add_scalar(
|
46 |
+
f"Validation/{name}acc_pred",
|
47 |
+
metrics["pixel_accs"].avg,
|
48 |
+
n_iter,
|
49 |
+
)
|
50 |
+
writer.add_scalar(
|
51 |
+
f"Validation/{name}f_max",
|
52 |
+
metrics["f_maxs"].avg,
|
53 |
+
n_iter,
|
54 |
+
)
|
55 |
+
|
56 |
+
@torch.no_grad()
|
57 |
+
def eval_batch(
|
58 |
+
batch_gt_masks,
|
59 |
+
batch_pred_masks,
|
60 |
+
metrics_res={},
|
61 |
+
reset=False
|
62 |
+
):
|
63 |
+
"""
|
64 |
+
Evaluation code adapted from SelfMask: https://github.com/NoelShin/selfmask
|
65 |
+
"""
|
66 |
+
|
67 |
+
f_values = {}
|
68 |
+
# Keep track of f_values for each threshold
|
69 |
+
for i in range(255): # should equal n_bins in metrics/f_measure.py
|
70 |
+
f_values[i] = AverageMeter()
|
71 |
+
|
72 |
+
if metrics_res == {}:
|
73 |
+
metrics_res["f_scores"] = AverageMeter()
|
74 |
+
metrics_res["f_maxs"] = AverageMeter()
|
75 |
+
metrics_res["f_maxs_fixed"] = AverageMeter()
|
76 |
+
metrics_res["f_means"] = AverageMeter()
|
77 |
+
metrics_res["maes"] = AverageMeter()
|
78 |
+
metrics_res["ious"] = AverageMeter()
|
79 |
+
metrics_res["pixel_accs"] = AverageMeter()
|
80 |
+
metrics_res["s_measures"] = AverageMeter()
|
81 |
+
|
82 |
+
if reset:
|
83 |
+
metrics_res["f_scores"].reset()
|
84 |
+
metrics_res["f_maxs"].reset()
|
85 |
+
metrics_res["f_maxs_fixed"].reset()
|
86 |
+
metrics_res["f_means"].reset()
|
87 |
+
metrics_res["maes"].reset()
|
88 |
+
metrics_res["ious"].reset()
|
89 |
+
metrics_res["pixel_accs"].reset()
|
90 |
+
metrics_res["s_measures"].reset()
|
91 |
+
|
92 |
+
# iterate over batch dimension
|
93 |
+
for _, (pred_mask, gt_mask) in enumerate(
|
94 |
+
zip(batch_pred_masks, batch_gt_masks)
|
95 |
+
):
|
96 |
+
assert pred_mask.shape == gt_mask.shape, f"{pred_mask.shape} != {gt_mask.shape}"
|
97 |
+
assert len(pred_mask.shape) == len(gt_mask.shape) == 2
|
98 |
+
# Compute
|
99 |
+
# Binarize at 0.5 for IoU and pixel accuracy
|
100 |
+
binary_pred = (pred_mask > 0.5).float().squeeze()
|
101 |
+
iou = compute_iou(binary_pred, gt_mask)
|
102 |
+
f_measures = FMeasure()(pred_mask, gt_mask) # soft mask for F measure
|
103 |
+
mae = compute_mae(binary_pred, gt_mask)
|
104 |
+
pixel_acc = compute_pixel_accuracy(binary_pred, gt_mask)
|
105 |
+
|
106 |
+
# Update
|
107 |
+
metrics_res["ious"].update(val=iou.numpy(), n=1)
|
108 |
+
metrics_res["f_scores"].update(val=f_measures["f_measure"].numpy(), n=1)
|
109 |
+
metrics_res["f_maxs"].update(val=f_measures["f_max"].numpy(), n=1)
|
110 |
+
metrics_res["f_means"].update(val=f_measures["f_mean"].numpy(), n=1)
|
111 |
+
metrics_res["s_measures"].update(
|
112 |
+
val=SMeasure()(pred_mask=pred_mask, gt_mask=gt_mask.to(torch.float32)), n=1
|
113 |
+
)
|
114 |
+
metrics_res["maes"].update(val=mae.numpy(), n=1)
|
115 |
+
metrics_res["pixel_accs"].update(val=pixel_acc.numpy(), n=1)
|
116 |
+
|
117 |
+
# Keep track of f_values for each threshold
|
118 |
+
all_f = f_measures["all_f"].numpy()
|
119 |
+
for k, v in f_values.items():
|
120 |
+
v.update(val=all_f[k], n=1)
|
121 |
+
# Then compute the max for the f_max_fixed
|
122 |
+
metrics_res["f_maxs_fixed"].update(
|
123 |
+
val=np.max([v.avg for v in f_values.values()]), n=1
|
124 |
+
)
|
125 |
+
|
126 |
+
results = {}
|
127 |
+
# F-measure, F-max, F-mean, MAE, S-measure, IoU, pixel acc.
|
128 |
+
results["f_measure"] = metrics_res["f_scores"].avg
|
129 |
+
results["f_max"] = metrics_res["f_maxs"].avg
|
130 |
+
results["f_maxs_fixed"] = metrics_res["f_maxs_fixed"].avg
|
131 |
+
results["f_mean"] = metrics_res["f_means"].avg
|
132 |
+
results["s_measure"] = metrics_res["s_measures"].avg
|
133 |
+
results["mae"] = metrics_res["maes"].avg
|
134 |
+
results["iou"] = float(iou.numpy())
|
135 |
+
results["pixel_acc"] = metrics_res["pixel_accs"].avg
|
136 |
+
|
137 |
+
return results, metrics_res
|
138 |
+
|
139 |
+
def evaluate_saliency(
|
140 |
+
dataset,
|
141 |
+
model,
|
142 |
+
writer=None,
|
143 |
+
batch_size=1,
|
144 |
+
n_iter=-1,
|
145 |
+
apply_bilateral=False,
|
146 |
+
im_fullsize=True,
|
147 |
+
method="pred", # can also be "bkg",
|
148 |
+
apply_weights: bool = True,
|
149 |
+
evaluation_mode: str = 'single', # choices are ["single", "multi"]
|
150 |
+
):
|
151 |
+
|
152 |
+
if im_fullsize:
|
153 |
+
# Change transformation
|
154 |
+
dataset.fullimg_mode()
|
155 |
+
batch_size = 1
|
156 |
+
|
157 |
+
valloader = torch.utils.data.DataLoader(
|
158 |
+
dataset,
|
159 |
+
batch_size=batch_size,
|
160 |
+
shuffle=False,
|
161 |
+
num_workers=2
|
162 |
+
)
|
163 |
+
|
164 |
+
sigmoid = nn.Sigmoid()
|
165 |
+
|
166 |
+
metrics_res = {}
|
167 |
+
metrics_res_bs = {}
|
168 |
+
valbar = tqdm(enumerate(valloader, 0), leave=None)
|
169 |
+
for i, data in valbar:
|
170 |
+
inputs, _, gt_labels, _ = data
|
171 |
+
inputs = inputs.to("cuda")
|
172 |
+
gt_labels = gt_labels.to("cuda").float()
|
173 |
+
|
174 |
+
# Forward step
|
175 |
+
with torch.no_grad():
|
176 |
+
preds, _, shape_f, att = model.forward_step(inputs, for_eval=True)
|
177 |
+
|
178 |
+
if method == "pred":
|
179 |
+
h, w = gt_labels.shape[-2:]
|
180 |
+
preds_up = F.interpolate(
|
181 |
+
preds, scale_factor=model.vit_patch_size, mode="bilinear", align_corners=False
|
182 |
+
)[..., :h, :w]
|
183 |
+
soft_preds = sigmoid(preds_up.detach()).squeeze(0)
|
184 |
+
preds_up = (
|
185 |
+
(sigmoid(preds_up.detach()) > 0.5).squeeze(0).float()
|
186 |
+
)
|
187 |
+
|
188 |
+
elif method == "bkg":
|
189 |
+
bkg_mask_pred = model.compute_background_batch(
|
190 |
+
att, shape_f,
|
191 |
+
apply_weights=apply_weights,
|
192 |
+
)
|
193 |
+
# Transform bkg detection to foreground detection
|
194 |
+
obj_mask = (
|
195 |
+
~bkg_mask_pred.bool()
|
196 |
+
).float() # Obj labels is inverse of bkg
|
197 |
+
|
198 |
+
# Fit predictions to image size
|
199 |
+
preds_up = F.interpolate(
|
200 |
+
obj_mask.unsqueeze(1),
|
201 |
+
gt_labels.shape[-2:],
|
202 |
+
mode="bilinear",
|
203 |
+
align_corners=False,
|
204 |
+
)
|
205 |
+
preds_up = (preds_up > 0.5).float()
|
206 |
+
soft_preds = preds_up # not soft actually
|
207 |
+
|
208 |
+
reset = True if i == 0 else False
|
209 |
+
if evaluation_mode == 'single':
|
210 |
+
labeled, nr_objects = ndimage.label(preds_up.squeeze().cpu().numpy())
|
211 |
+
if nr_objects == 0:
|
212 |
+
preds_up_one_cc = preds_up.squeeze()
|
213 |
+
print("nr_objects == 0")
|
214 |
+
else:
|
215 |
+
nb_pixel = [np.sum(labeled == i) for i in range(nr_objects + 1)]
|
216 |
+
pixel_order = np.argsort(nb_pixel)
|
217 |
+
|
218 |
+
cc = [torch.Tensor(labeled == i) for i in pixel_order]
|
219 |
+
cc = torch.stack(cc).cuda()
|
220 |
+
|
221 |
+
# Find CC set as background, here not necessarily the biggest
|
222 |
+
cc_background = (
|
223 |
+
(
|
224 |
+
(
|
225 |
+
(~(preds_up[None, :, :, :].bool())).float()
|
226 |
+
+ cc[:, None, :, :].cuda()
|
227 |
+
)
|
228 |
+
> 1
|
229 |
+
).sum(-1).sum(-1).argmax()
|
230 |
+
)
|
231 |
+
pixel_order = np.delete(
|
232 |
+
pixel_order, int(cc_background.cpu().numpy())
|
233 |
+
)
|
234 |
+
|
235 |
+
preds_up_one_cc = torch.Tensor(labeled == pixel_order[-1]).cuda()
|
236 |
+
|
237 |
+
_, metrics_res = eval_batch(
|
238 |
+
gt_labels,
|
239 |
+
preds_up_one_cc.unsqueeze(0),
|
240 |
+
metrics_res=metrics_res,
|
241 |
+
reset=reset,
|
242 |
+
)
|
243 |
+
|
244 |
+
if writer is not None:
|
245 |
+
write_metric_tf(writer, metrics_res, n_iter=n_iter, name=f"_{evaluation_mode}_")
|
246 |
+
|
247 |
+
elif evaluation_mode == 'multi':
|
248 |
+
# Eval without bilateral solver
|
249 |
+
_, metrics_res = eval_batch(
|
250 |
+
gt_labels,
|
251 |
+
soft_preds.unsqueeze(0) if len(soft_preds.shape) == 2 else soft_preds,
|
252 |
+
metrics_res=metrics_res,
|
253 |
+
reset=reset,
|
254 |
+
) # soft preds needed for F beta measure
|
255 |
+
|
256 |
+
# Apply bilateral solver
|
257 |
+
preds_bs = None
|
258 |
+
if apply_bilateral:
|
259 |
+
get_all_cc = True if evaluation_mode == 'multi' else False
|
260 |
+
preds_bs, _ = batch_apply_bilateral_solver(data,
|
261 |
+
preds_up.detach(),
|
262 |
+
get_all_cc = get_all_cc
|
263 |
+
)
|
264 |
+
|
265 |
+
_, metrics_res_bs = eval_batch(
|
266 |
+
gt_labels,
|
267 |
+
preds_bs[None,:,:].float(),
|
268 |
+
metrics_res=metrics_res_bs,
|
269 |
+
reset=reset
|
270 |
+
)
|
271 |
+
|
272 |
+
if writer is not None:
|
273 |
+
write_metric_tf(writer, metrics_res_bs, n_iter=n_iter, name=f"_{evaluation_mode}-BS_")
|
274 |
+
|
275 |
+
bar_str = f"{dataset.name} | {evaluation_mode} mode | " \
|
276 |
+
f"F-max {metrics_res['f_maxs'].avg:.3f} " \
|
277 |
+
f"IoU {metrics_res['ious'].avg:.3f}, " \
|
278 |
+
f"PA {metrics_res['pixel_accs'].avg:.3f}"
|
279 |
+
|
280 |
+
if apply_bilateral:
|
281 |
+
bar_str += f" | with bilateral solver: " \
|
282 |
+
f"F-max {metrics_res_bs['f_maxs'].avg:.3f}, " \
|
283 |
+
f"IoU {metrics_res_bs['ious'].avg:.3f}, " \
|
284 |
+
f"PA. {metrics_res_bs['pixel_accs'].avg:.3f}"
|
285 |
+
|
286 |
+
valbar.set_description(bar_str)
|
287 |
+
|
288 |
+
# Go back to original transformation
|
289 |
+
if im_fullsize:
|
290 |
+
dataset.training_mode()
|
evaluation/uod.py
ADDED
@@ -0,0 +1,118 @@
1 |
+
# Copyright 2021 - Valeo Comfort and Driving Assistance - Oriane Siméoni @ valeo.ai
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""
|
16 |
+
Code adapted from previous method LOST: https://github.com/valeoai/LOST
|
17 |
+
"""
|
18 |
+
|
19 |
+
import os
|
20 |
+
import time
|
21 |
+
import torch
|
22 |
+
import torch.nn as nn
|
23 |
+
import numpy as np
|
24 |
+
|
25 |
+
from tqdm import tqdm
|
26 |
+
from misc import bbox_iou, get_bbox_from_segmentation_labels
|
27 |
+
|
28 |
+
|
29 |
+
def evaluation_unsupervised_object_discovery(
|
30 |
+
dataset,
|
31 |
+
model,
|
32 |
+
evaluation_mode: str = 'single', # choices are ["single", "multi"]
|
33 |
+
output_dir:str = "outputs",
|
34 |
+
no_hards:bool = False,
|
35 |
+
):
|
36 |
+
|
37 |
+
assert evaluation_mode == "single"
|
38 |
+
|
39 |
+
sigmoid = nn.Sigmoid()
|
40 |
+
|
41 |
+
# ----------------------------------------------------
|
42 |
+
# Loop over images
|
43 |
+
preds_dict = {}
|
44 |
+
cnt = 0
|
45 |
+
corloc = np.zeros(len(dataset.dataloader))
|
46 |
+
|
47 |
+
start_time = time.time()
|
48 |
+
pbar = tqdm(dataset.dataloader)
|
49 |
+
for im_id, inp in enumerate(pbar):
|
50 |
+
|
51 |
+
# ------------ IMAGE PROCESSING -------------------------------------------
|
52 |
+
img = inp[0]
|
53 |
+
|
54 |
+
init_image_size = img.shape
|
55 |
+
|
56 |
+
# Get the name of the image
|
57 |
+
im_name = dataset.get_image_name(inp[1])
|
58 |
+
# Pass in case of no gt boxes in the image
|
59 |
+
if im_name is None:
|
60 |
+
continue
|
61 |
+
|
62 |
+
# Padding the image with zeros to fit multiple of patch-size
|
63 |
+
size_im = (
|
64 |
+
img.shape[0],
|
65 |
+
int(np.ceil(img.shape[1] / model.vit_patch_size) * model.vit_patch_size),
|
66 |
+
int(np.ceil(img.shape[2] / model.vit_patch_size) * model.vit_patch_size),
|
67 |
+
)
|
68 |
+
paded = torch.zeros(size_im)
|
69 |
+
paded[:, : img.shape[1], : img.shape[2]] = img
|
70 |
+
img = paded
|
71 |
+
|
72 |
+
# # Move to gpu
|
73 |
+
img = img.cuda(non_blocking=True)
|
74 |
+
|
75 |
+
# Size for transformers
|
76 |
+
w_featmap = img.shape[-2] // model.vit_patch_size
|
77 |
+
h_featmap = img.shape[-1] // model.vit_patch_size
|
78 |
+
|
79 |
+
# ------------ GROUND-TRUTH -------------------------------------------
|
80 |
+
gt_bbxs, gt_cls = dataset.extract_gt(inp[1], im_name)
|
81 |
+
|
82 |
+
if gt_bbxs is not None:
|
83 |
+
# Discard images with no gt annotations
|
84 |
+
# Happens only in the case of VOC07 and VOC12
|
85 |
+
if gt_bbxs.shape[0] == 0 and no_hards:
|
86 |
+
continue
|
87 |
+
|
88 |
+
outputs = model.forward_step(img[None, :, :, :])
|
89 |
+
preds = (sigmoid(outputs[0].detach()) > 0.5).float().squeeze().cpu().numpy()
|
90 |
+
|
91 |
+
# get bbox
|
92 |
+
pred = get_bbox_from_segmentation_labels(
|
93 |
+
segmenter_predictions=preds,
|
94 |
+
scales=[model.vit_patch_size, model.vit_patch_size],
|
95 |
+
initial_image_size=init_image_size[1:],
|
96 |
+
)
|
97 |
+
|
98 |
+
# ------------ Visualizations -------------------------------------------
|
99 |
+
# Save the prediction
|
100 |
+
preds_dict[im_name] = pred
|
101 |
+
|
102 |
+
|
103 |
+
# Compare prediction to GT boxes
|
104 |
+
ious = bbox_iou(torch.from_numpy(pred), torch.from_numpy(gt_bbxs))
|
105 |
+
|
106 |
+
if torch.any(ious >= 0.5):
|
107 |
+
corloc[im_id] = 1
|
108 |
+
|
109 |
+
cnt += 1
|
110 |
+
if cnt % 50 == 0:
|
111 |
+
pbar.set_description(f"Found {int(np.sum(corloc))}/{cnt}")
|
112 |
+
|
113 |
+
# Evaluate
|
114 |
+
print(f"corloc: {100*np.sum(corloc)/cnt:.2f} ({int(np.sum(corloc))}/{cnt})")
|
115 |
+
result_file = os.path.join(output_dir, 'uod_results.txt')
|
116 |
+
with open(result_file, 'w') as f:
|
117 |
+
f.write('corloc,%.1f,,\n'%(100*np.sum(corloc)/cnt))
|
118 |
+
print('File saved at %s'%result_file)
|
main_found_evaluate.py
ADDED
@@ -0,0 +1,122 @@
1 |
+
# Copyright 2022 - Valeo Comfort and Driving Assistance - Oriane Siméoni @ valeo.ai
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import argparse

from model import FoundModel
from misc import load_config
from datasets.datasets import build_dataset
from evaluation.saliency import evaluate_saliency
from evaluation.uod import evaluation_unsupervised_object_discovery


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Evaluation of FOUND',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--eval-type",
        type=str,
        choices=["saliency", "uod"],
        help="Evaluation type."
    )
    parser.add_argument(
        "--dataset-eval",
        type=str,
        choices=["ECSSD", "DUT-OMRON", "DUTS-TEST", "VOC07", "VOC12", "COCO20k"],
        help="Name of evaluation dataset."
    )
    parser.add_argument(
        "--dataset-set-eval",
        type=str,
        default=None,
        help="Set of the dataset."
    )
    parser.add_argument(
        "--apply-bilateral",
        action="store_true",
        help="Use the bilateral solver."
    )
    parser.add_argument(
        "--evaluation-mode",
        type=str,
        default="multi",
        choices=["single", "multi"],
        help="Type of evaluation."
    )
    parser.add_argument(
        "--model-weights",
        type=str,
        default="data/weights/decoder_weights.pt",
    )
    parser.add_argument(
        "--dataset-dir",
        type=str,
        default="/datasets_local",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="configs/found_DUTS-TR.yaml",
    )
    args = parser.parse_args()
    print(args.__dict__)

    # Configuration
    config = load_config(args.config)

    # ------------------------------------
    # Load the model
    model = FoundModel(vit_model=config.model["pre_training"],
                       vit_arch=config.model["arch"],
                       vit_patch_size=config.model["patch_size"],
                       enc_type_feats=config.found["feats"],
                       bkg_type_feats=config.found["feats"],
                       bkg_th=config.found["bkg_th"])
    # Load weights
    model.decoder_load_weights(args.model_weights)
    model.eval()
    print(f"Model {args.model_weights} loaded correctly.")

    # ------------------------------------
    # Build the validation set
    val_dataset = build_dataset(
        root_dir=args.dataset_dir,
        dataset_name=args.dataset_eval,
        dataset_set=args.dataset_set_eval,
        for_eval=True,
        evaluation_type=args.eval_type,
    )
    print(f"\nBuilding dataset {val_dataset.name} (#{len(val_dataset)} images)")

    # ------------------------------------
    # Evaluation
    print(f"\nStarted evaluation on {val_dataset.name}")
    if args.eval_type == "saliency":
        evaluate_saliency(
            val_dataset,
            model=model,
            evaluation_mode=args.evaluation_mode,
            apply_bilateral=args.apply_bilateral,
        )
    elif args.eval_type == "uod":
        if args.apply_bilateral:
            raise ValueError("Not implemented.")

        evaluation_unsupervised_object_discovery(
            val_dataset,
            model=model,
            evaluation_mode=args.evaluation_mode,
        )
    else:
        raise ValueError("Other evaluation method to come.")
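The script above is driven entirely from the command line. As a sketch (the dataset being available under the default --dataset-dir is an assumption, not something this commit guarantees), a saliency evaluation could be launched with

    python main_found_evaluate.py --eval-type saliency --dataset-eval DUTS-TEST --evaluation-mode multi --apply-bilateral

and unsupervised object discovery with

    python main_found_evaluate.py --eval-type uod --dataset-eval VOC07 --evaluation-mode single

Note that --apply-bilateral is rejected for uod, since that branch raises ValueError("Not implemented.").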
main_visualize.py
ADDED
@@ -0,0 +1,99 @@
# Copyright 2022 - Valeo Comfort and Driving Assistance - Oriane Siméoni @ valeo.ai
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
import argparse
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from PIL import Image
from model import FoundModel
from misc import load_config
from torchvision import transforms as T

NORMALIZE = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Visualization of FOUND predictions',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--img-path", type=str, default="data/examples/VOC07_000007.jpg", help="Image path."
    )
    parser.add_argument(
        "--model-weights", type=str, default="data/weights/decoder_weights.pt",
    )
    parser.add_argument(
        "--config", type=str, default="configs/found_DUTS-TR.yaml",
    )
    parser.add_argument(
        "--output-dir", type=str, default="outputs",
    )
    args = parser.parse_args()

    # Saving dir
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # Configuration
    config = load_config(args.config)

    # ------------------------------------
    # Load the model
    model = FoundModel(vit_model=config.model["pre_training"],
                       vit_arch=config.model["arch"],
                       vit_patch_size=config.model["patch_size"],
                       enc_type_feats=config.found["feats"],
                       bkg_type_feats=config.found["feats"],
                       bkg_th=config.found["bkg_th"])
    # Load weights
    model.decoder_load_weights(args.model_weights)
    model.eval()
    print(f"Model {args.model_weights} loaded correctly.")

    # Load the image
    with open(args.img_path, "rb") as f:
        img = Image.open(f)
        img = img.convert("RGB")

    t = T.Compose([T.ToTensor(), NORMALIZE])
    img_t = t(img)[None, :, :, :]
    inputs = img_t.to("cuda")

    # Forward step
    with torch.no_grad():
        preds, _, shape_f, att = model.forward_step(inputs, for_eval=True)

    # Apply FOUND
    sigmoid = nn.Sigmoid()
    h, w = img_t.shape[-2:]
    preds_up = F.interpolate(
        preds, scale_factor=model.vit_patch_size, mode="bilinear", align_corners=False
    )[..., :h, :w]
    preds_up = (
        (sigmoid(preds_up.detach()) > 0.5).squeeze(0).float()
    )

    plt.figure()
    plt.imshow(img)
    plt.imshow(preds_up.cpu().squeeze().numpy(), 'gray', interpolation='none', alpha=0.5)
    plt.axis('off')
    img_name = args.img_path
    img_name = img_name.split('/')[-1].split('.')[0]
    plt.savefig(os.path.join(args.output_dir, f'{img_name}-found.png'), bbox_inches='tight', pad_inches=0)
    plt.close()
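main_visualize.py thresholds the decoder output at patch resolution and upsamples it by the ViT patch size before overlaying it on the image. Below is a minimal sketch of that post-processing step with dummy tensors; the 224x224 input and patch size 8 are assumptions for illustration, matching the ViT-S/8 default used elsewhere in this commit.

import torch
import torch.nn.functional as F

preds = torch.randn(1, 1, 28, 28)  # stand-in for the decoder logits on the 28x28 patch grid (224 / 8)
preds_up = F.interpolate(preds, scale_factor=8, mode="bilinear", align_corners=False)  # back to 224x224
mask = (torch.sigmoid(preds_up) > 0.5).float()  # binary foreground mask, as in the script
print(mask.shape)  # torch.Size([1, 1, 224, 224])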
misc.py
ADDED
@@ -0,0 +1,254 @@
import re
import os
import cv2
import yaml
import math
import random
import scipy.ndimage
import numpy as np

import torch
import torch.nn.functional as F

from typing import List
from torchvision import transforms as T

from bilateral_solver import bilateral_solver_output


loader = yaml.SafeLoader
loader.add_implicit_resolver(
    u'tag:yaml.org,2002:float',
    re.compile(u'''^(?:
     [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
    |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
    |\\.[0-9_]+(?:[eE][-+][0-9]+)?
    |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
    |[-+]?\\.(?:inf|Inf|INF)
    |\\.(?:nan|NaN|NAN))$''', re.X),
    list(u'-+0123456789.'))

class Struct:
    def __init__(self, **entries):
        self.__dict__.update(entries)

def load_config(config_file):
    with open(config_file, errors='ignore') as f:
        # conf = yaml.safe_load(f)  # load config
        conf = yaml.load(f, Loader=loader)
        print('hyperparameters: ' + ', '.join(f'{k}={v}' for k, v in conf.items()))

    # TODO yaml_save(save_dir / 'config.yaml', conf)
    return Struct(**conf)

def set_seed(seed: int) -> None:
    """
    Set all seeds to make results reproducible.
    """
    # env
    os.environ["PYTHONHASHSEED"] = str(seed)

    # python
    random.seed(seed)

    # numpy
    np.random.seed(seed)

    # torch
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

def IoU(mask1, mask2):
    """
    Code adapted from TokenCut: https://github.com/YangtaoWANG95/TokenCut
    """
    mask1, mask2 = (mask1 > 0.5).to(torch.bool), (mask2 > 0.5).to(torch.bool)
    intersection = torch.sum(mask1 * (mask1 == mask2), dim=[-1, -2]).squeeze()
    union = torch.sum(mask1 + mask2, dim=[-1, -2]).squeeze()
    return (intersection.to(torch.float) / union).mean().item()

def batch_apply_bilateral_solver(data,
                                 masks,
                                 get_all_cc=True,
                                 shape=None):

    cnt_bs = 0
    masks_bs = []
    inputs, init_imgs, gt_labels, img_path = data

    for id in range(inputs.shape[0]):
        _, bs_mask, use_bs = apply_bilateral_solver(
            mask=masks[id].squeeze().cpu().numpy(),
            img=init_imgs[id],
            img_path=img_path[id],
            im_fullsize=False,
            # Careful: the shape is given transposed, as (w, h)
            shape=(gt_labels.shape[-1], gt_labels.shape[-2]),
            get_all_cc=get_all_cc,
        )
        cnt_bs += use_bs

        # Use the bilateral solver output if IoU > 0.5
        if use_bs:
            if shape is None:
                shape = masks.shape[-2:]
            # Interpolate to downsample the mask back
            bs_ds = F.interpolate(
                torch.Tensor(bs_mask).unsqueeze(0).unsqueeze(0),
                shape,  # TODO check here
                mode="bilinear",
                align_corners=False,
            )
            masks_bs.append(bs_ds.bool().cuda().squeeze()[None, :, :])
        else:
            # Use the initial mask
            masks_bs.append(masks[id].cuda().squeeze()[None, :, :])

    return torch.cat(masks_bs).squeeze(), cnt_bs


def apply_bilateral_solver(
    mask,
    img,
    img_path,
    shape,
    im_fullsize=False,
    get_all_cc=False,
    bs_iou_threshold: float = 0.5,
    reshape: bool = True,
):
    # Get the initial image in the case of using the full image
    img_init = None
    if not im_fullsize:
        # Use the image given by the dataloader
        shape = (img.shape[-1], img.shape[-2])
        t = T.ToPILImage()
        img_init = t(img)

    if reshape:
        # Resize predictions to image size
        resized_mask = cv2.resize(mask, shape)
        sel_obj_mask = resized_mask
    else:
        resized_mask = mask
        sel_obj_mask = mask

    # Apply the bilateral solver
    _, binary_solver = bilateral_solver_output(
        img_path,
        resized_mask,
        img=img_init,
        sigma_spatial=16,
        sigma_luma=16,
        sigma_chroma=8,
        get_all_cc=get_all_cc,
    )

    mask1 = torch.from_numpy(resized_mask).cuda()
    mask2 = torch.from_numpy(binary_solver).cuda().float()

    use_bs = 0
    # If there is enough overlap, use the bilateral solver output
    if IoU(mask1, mask2) > bs_iou_threshold:
        sel_obj_mask = binary_solver.astype(float)
        use_bs = 1

    return resized_mask, sel_obj_mask, use_bs

def get_bbox_from_segmentation_labels(
    segmenter_predictions: torch.Tensor,
    initial_image_size: torch.Size,
    scales: List[int],
) -> np.array:
    """
    Find the largest connected component in the foreground and extract its bounding box.
    """
    objects, num_objects = scipy.ndimage.label(segmenter_predictions)

    # Find the biggest connected component
    all_foreground_labels = objects.flatten()[objects.flatten() != 0]
    most_frequent_label = np.bincount(all_foreground_labels).argmax()
    mask = np.where(objects == most_frequent_label)
    # Add +1 because the max is excluded
    ymin, ymax = min(mask[0]), max(mask[0]) + 1
    xmin, xmax = min(mask[1]), max(mask[1]) + 1

    if initial_image_size == segmenter_predictions.shape:
        # Masks are already upsampled
        pred = [xmin, ymin, xmax, ymax]
    else:
        # Rescale to image size
        r_xmin, r_xmax = scales[1] * xmin, scales[1] * xmax
        r_ymin, r_ymax = scales[0] * ymin, scales[0] * ymax
        pred = [r_xmin, r_ymin, r_xmax, r_ymax]

    # Check that the box does not exceed the image size (used when padding)
    if initial_image_size:
        pred[2] = min(pred[2], initial_image_size[1])
        pred[3] = min(pred[3], initial_image_size[0])

    return np.asarray(pred)


def bbox_iou(
    box1: np.array,
    box2: np.array,
    x1y1x2y2: bool = True,
    GIoU: bool = False,
    DIoU: bool = False,
    CIoU: bool = False,
    eps: float = 1e-7,
):
    # https://github.com/ultralytics/yolov5/blob/develop/utils/general.py
    # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
    box2 = box2.T

    # Get the coordinates of bounding boxes
    if x1y1x2y2:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
    else:  # transform from xywh to xyxy
        b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
        b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
        b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
        b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2

    # Intersection area
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * (
        torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)
    ).clamp(0)

    # Union area
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
    union = w1 * h1 + w2 * h2 - inter + eps

    iou = inter / union
    if GIoU or DIoU or CIoU:
        cw = torch.max(b1_x2, b2_x2) - torch.min(
            b1_x1, b2_x1
        )  # convex (smallest enclosing box) width
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw**2 + ch**2 + eps  # convex diagonal squared
            rho2 = (
                (b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2
                + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2
            ) / 4  # center distance squared
            if DIoU:
                return iou - rho2 / c2  # DIoU
            elif (
                CIoU
            ):  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi**2) * torch.pow(
                    torch.atan(w2 / h2) - torch.atan(w1 / h1), 2
                )
                with torch.no_grad():
                    alpha = v / (v - iou + (1 + eps))
                return iou - (rho2 / c2 + v * alpha)  # CIoU
        else:  # GIoU https://arxiv.org/pdf/1902.09630.pdf
            c_area = cw * ch + eps  # convex area
            return iou - (c_area - union) / c_area  # GIoU
    else:
        return iou  # IoU
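For a quick sanity check of the box utility above, bbox_iou can be called directly with (x1, y1, x2, y2) boxes. A sketch using torch tensors, which is what the torch.min/torch.max calls inside expect (the values in the comments are approximate):

import torch
from misc import bbox_iou

box = torch.tensor([10.0, 10.0, 50.0, 50.0])
boxes = torch.tensor([[10.0, 10.0, 50.0, 50.0],
                      [30.0, 30.0, 70.0, 70.0]])
print(bbox_iou(box, boxes))             # plain IoU, roughly [1.00, 0.14]
print(bbox_iou(box, boxes, GIoU=True))  # generalized IoU, penalizes the weakly overlapping pair

Note that boxes is passed as an n x 4 tensor, since the function transposes box2 internally.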
model.py
ADDED
@@ -0,0 +1,243 @@
# Copyright 2022 - Valeo Comfort and Driving Assistance - Oriane Siméoni @ valeo.ai
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import torch
import torch.nn as nn
import dino.vision_transformer as vits

from bkg_seg import compute_img_bkg_seg
from misc import batch_apply_bilateral_solver

class FoundModel(nn.Module):
    def __init__(
        self,
        vit_model="dino",
        vit_arch="vit_small",
        vit_patch_size=8,
        enc_type_feats="k",
        bkg_type_feats="k",
        bkg_th=0.3
    ):

        super(FoundModel, self).__init__()

        # ----------------------
        # Encoder
        self.vit_encoder, self.initial_dim, self.hook_features = get_vit_encoder(
            vit_arch, vit_model, vit_patch_size, enc_type_feats
        )
        self.vit_patch_size = vit_patch_size
        self.enc_type_feats = enc_type_feats

        # ----------------------
        # Background segmentation
        self.bkg_type_feats = bkg_type_feats
        self.bkg_th = bkg_th

        # ----------------------
        # Define the simple decoder
        self.previous_dim = self.initial_dim
        self.decoder = nn.Conv2d(self.previous_dim, 1, (1, 1))

    def forward_step(self, batch, decoder=None, for_eval=False):

        # Make the image divisible by the patch size
        if for_eval:
            batch = self.make_input_divisible(batch)
            _w, _h = batch.shape[-2:]
            _h, _w = _h // self.vit_patch_size, _w // self.vit_patch_size
        else:
            # Cropping used during training, could be changed to improve
            w, h = (
                batch.shape[-2] - batch.shape[-2] % self.vit_patch_size,
                batch.shape[-1] - batch.shape[-1] % self.vit_patch_size,
            )
            batch = batch[:, :, :w, :h]

        w_featmap = batch.shape[-2] // self.vit_patch_size
        h_featmap = batch.shape[-1] // self.vit_patch_size

        # Forward pass
        with torch.no_grad():
            # Encoder forward pass
            att = self.vit_encoder.get_last_selfattention(batch)

        # Get decoder features
        feats = self.extract_feats(dims=att.shape, type_feats=self.enc_type_feats)
        feats = feats[:, 1:, :, :].reshape(att.shape[0], w_featmap, h_featmap, -1)
        feats = feats.permute(0, 3, 1, 2)

        # Apply decoder
        if decoder is None:
            decoder = self.decoder
        preds = decoder(feats)

        # return preds_masked
        return preds, feats, (w_featmap, h_featmap), att

    def make_input_divisible(self, x: torch.Tensor) -> torch.Tensor:
        # From selfmask
        """Pad some pixels to make the input size divisible by the patch size."""
        B, _, H_0, W_0 = x.shape
        pad_w = (self.vit_patch_size - W_0 % self.vit_patch_size) % self.vit_patch_size
        pad_h = (self.vit_patch_size - H_0 % self.vit_patch_size) % self.vit_patch_size

        x = nn.functional.pad(x, (0, pad_w, 0, pad_h), value=0)
        return x

    def compute_background_batch(
        self,
        att,
        shape_f,
        # mlp_feats = None,
    ):

        w_f, h_f = shape_f

        # Dimensions
        nb_im = att.shape[0]  # Batch size
        nh = att.shape[1]  # Number of heads
        nb_tokens = att.shape[2]  # Number of tokens

        # Get decoder features
        feats = self.extract_feats(dims=att.shape,
                                   # mlp_feats = mlp_feats,
                                   type_feats=self.bkg_type_feats
                                   )
        feats = feats.reshape(nb_im, nb_tokens, -1)

        bkg_mask = compute_img_bkg_seg(
            att,
            feats,
            (w_f, h_f),
            th_bkg=self.bkg_th,
            dim=int(self.initial_dim / nh),
        )

        return bkg_mask


    def get_bkg_pseudo_labels_batch(
        self,
        att,
        shape_f,
        data,
        use_bilateral_solver=True,
        shape=None,
    ):

        bkg_mask_pred = self.compute_background_batch(
            att, shape_f
        )
        # Transform bkg detection to foreground detection
        # Object mask is the inverse of the bkg mask
        obj_mask = (~bkg_mask_pred.bool()).float()

        if use_bilateral_solver:
            pseudo_labels, cnt_bs = batch_apply_bilateral_solver(data, obj_mask, shape)
            return pseudo_labels, cnt_bs
        else:
            return obj_mask, 0

    @torch.no_grad()
    def decoder_load_weights(self, weights_path):
        print(f"Loading model from weights {weights_path}.")
        # Load states
        state_dict = torch.load(weights_path)

        # Decoder
        self.decoder.load_state_dict(state_dict["decoder"])
        self.decoder.eval()
        self.decoder.to("cuda")


    @torch.no_grad()
    def decoder_save_weights(self, save_dir, n_iter):
        state_dict = {}
        state_dict["decoder"] = self.decoder.state_dict()
        fname = os.path.join(
            save_dir, f"decoder_weights_niter{n_iter}.pt"
        )
        torch.save(state_dict, fname)
        print(f"\n----"
              f"\nModel saved at {fname}"
              )

    @torch.no_grad()
    def extract_feats(self, dims, type_feats="k"):

        nb_im, nh, nb_tokens, _ = dims
        qkv = (
            self.hook_features["qkv"]
            .reshape(
                nb_im, nb_tokens, 3, nh, -1 // nh
            )  # 3 corresponding to |qkv|
            .permute(2, 0, 3, 1, 4)
        )

        q, k, v = qkv[0], qkv[1], qkv[2]

        if type_feats == "q":
            return q.transpose(1, 2).float()
        elif type_feats == "k":
            return k.transpose(1, 2).float()
        elif type_feats == "v":
            return v.transpose(1, 2).float()
        else:
            raise ValueError("Unknown features")


def get_vit_encoder(vit_arch, vit_model, vit_patch_size, enc_type_feats):
    if vit_arch == "vit_small" and vit_patch_size == 16:
        url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
        initial_dim = 384
    elif vit_arch == "vit_small" and vit_patch_size == 8:
        url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"
        initial_dim = 384
    elif vit_arch == "vit_base" and vit_patch_size == 16:
        if vit_model == "clip":
            url = "5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"
        elif vit_model == "dino":
            url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
        initial_dim = 768
    elif vit_arch == "vit_base" and vit_patch_size == 8:
        url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
        initial_dim = 768

    if vit_model == "dino":
        vit_encoder = vits.__dict__[vit_arch](patch_size=vit_patch_size, num_classes=0)
        # TODO change if want to have last layer not unfrozen
        for p in vit_encoder.parameters():
            p.requires_grad = False
        vit_encoder.eval().cuda()  # eval mode
        state_dict = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/dino/" + url
        )
        vit_encoder.load_state_dict(state_dict, strict=True)

    hook_features = {}
    if enc_type_feats in ["k", "q", "v", "qkv", "mlp"]:
        # Define the hook
        def hook_fn_forward_qkv(module, input, output):
            hook_features["qkv"] = output

        vit_encoder._modules["blocks"][-1]._modules["attn"]._modules[
            "qkv"
        ].register_forward_hook(hook_fn_forward_qkv)
    else:
        raise ValueError("Not implemented.")

    return vit_encoder, initial_dim, hook_features
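FoundModel wraps a frozen DINO ViT encoder with a single 1x1 convolution decoder over the patch tokens. A minimal usage sketch follows; it assumes a CUDA device and network access so torch.hub can fetch the DINO weights, and the decoder stays randomly initialized unless decoder_load_weights is called as in the scripts above.

import torch
from model import FoundModel

model = FoundModel(vit_model="dino", vit_arch="vit_small", vit_patch_size=8, enc_type_feats="k")
model = model.cuda().eval()

img = torch.rand(1, 3, 224, 224, device="cuda")  # stand-in for a normalized RGB image
with torch.no_grad():
    preds, feats, (w_f, h_f), att = model.forward_step(img, for_eval=True)
print(preds.shape)  # torch.Size([1, 1, 28, 28]): one foreground logit per 8x8 patch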
requirements.txt
ADDED
@@ -0,0 +1,10 @@
pyyaml
matplotlib==3.5.2
numpy==1.21.4
opencv-python==4.5.5.64
opencv-python-headless==4.5.5.64
scipy==1.7.3
tensorboard
tqdm==4.64.0
pycocotools==2.0.4
Pillow==9.1.1
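Note that torch and torchvision are not pinned in this requirements file even though the code imports them (model.py, misc.py, main_visualize.py); they have to be installed separately, for example with pip install torch torchvision alongside pip install -r requirements.txt. The dino.vision_transformer module imported by model.py is likewise not listed here.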