Spaces (Runtime error)
yael-vinker committed
Commit 3c149ed • 1 Parent(s): 253b0de
- CLIPasso-local-demo.ipynb +0 -0
- CLIPasso.ipynb +0 -0
- LICENSE +437 -0
- U2Net_/model/__init__.py +2 -0
- U2Net_/model/u2net.py +525 -0
- U2Net_/model/u2net_refactor.py +168 -0
- U2Net_/saved_models/face_detection_cv2/haarcascade_frontalface_default.xml +0 -0
- cog.yaml +54 -0
- config.py +144 -0
- display_results.py +94 -0
- models/loss.py +463 -0
- models/painter_params.py +539 -0
- painterly_rendering.py +190 -0
- predict.py +317 -0
- requirements.txt +53 -0
- run_object_sketching.py +158 -0
- sketch_utils.py +295 -0
- target_images/camel.png +0 -0
- target_images/flamingo.png +0 -0
- target_images/horse.png +0 -0
- target_images/rose.jpeg +0 -0
CLIPasso-local-demo.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
CLIPasso.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
LICENSE
ADDED
@@ -0,0 +1,437 @@
Attribution-NonCommercial-ShareAlike 4.0 International

(The 437 added lines are the unmodified legal code of the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License, available at creativecommons.org/licenses/by-nc-sa/4.0/legalcode.)
U2Net_/model/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .u2net import U2NET
from .u2net import U2NETP
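The two-line package init above simply re-exports the model classes defined next in u2net.py. A minimal, hypothetical import sketch (it assumes the repository root is on PYTHONPATH and PyTorch is installed):

from U2Net_.model import U2NET, U2NETP   # re-exported by this __init__

net = U2NETP(in_ch=3, out_ch=1)          # the lightweight variant; U2NET is the full-size model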
U2Net_/model/u2net.py
ADDED
@@ -0,0 +1,525 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
class REBNCONV(nn.Module):
|
6 |
+
def __init__(self,in_ch=3,out_ch=3,dirate=1):
|
7 |
+
super(REBNCONV,self).__init__()
|
8 |
+
|
9 |
+
self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)
|
10 |
+
self.bn_s1 = nn.BatchNorm2d(out_ch)
|
11 |
+
self.relu_s1 = nn.ReLU(inplace=True)
|
12 |
+
|
13 |
+
def forward(self,x):
|
14 |
+
|
15 |
+
hx = x
|
16 |
+
xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
|
17 |
+
|
18 |
+
return xout
|
19 |
+
|
20 |
+
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
|
21 |
+
def _upsample_like(src,tar):
|
22 |
+
|
23 |
+
src = F.upsample(src,size=tar.shape[2:],mode='bilinear')
|
24 |
+
|
25 |
+
return src
|
26 |
+
|
27 |
+
|
28 |
+
### RSU-7 ###
|
29 |
+
class RSU7(nn.Module):#UNet07DRES(nn.Module):
|
30 |
+
|
31 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
32 |
+
super(RSU7,self).__init__()
|
33 |
+
|
34 |
+
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
|
35 |
+
|
36 |
+
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
|
37 |
+
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
38 |
+
|
39 |
+
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
40 |
+
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
41 |
+
|
42 |
+
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
43 |
+
self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
44 |
+
|
45 |
+
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
46 |
+
self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
47 |
+
|
48 |
+
self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
49 |
+
self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
50 |
+
|
51 |
+
self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
52 |
+
|
53 |
+
self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
|
54 |
+
|
55 |
+
self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
56 |
+
self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
57 |
+
self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
58 |
+
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
59 |
+
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
60 |
+
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
|
61 |
+
|
62 |
+
def forward(self,x):
|
63 |
+
|
64 |
+
hx = x
|
65 |
+
hxin = self.rebnconvin(hx)
|
66 |
+
|
67 |
+
hx1 = self.rebnconv1(hxin)
|
68 |
+
hx = self.pool1(hx1)
|
69 |
+
|
70 |
+
hx2 = self.rebnconv2(hx)
|
71 |
+
hx = self.pool2(hx2)
|
72 |
+
|
73 |
+
hx3 = self.rebnconv3(hx)
|
74 |
+
hx = self.pool3(hx3)
|
75 |
+
|
76 |
+
hx4 = self.rebnconv4(hx)
|
77 |
+
hx = self.pool4(hx4)
|
78 |
+
|
79 |
+
hx5 = self.rebnconv5(hx)
|
80 |
+
hx = self.pool5(hx5)
|
81 |
+
|
82 |
+
hx6 = self.rebnconv6(hx)
|
83 |
+
|
84 |
+
hx7 = self.rebnconv7(hx6)
|
85 |
+
|
86 |
+
hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
|
87 |
+
hx6dup = _upsample_like(hx6d,hx5)
|
88 |
+
|
89 |
+
hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
|
90 |
+
hx5dup = _upsample_like(hx5d,hx4)
|
91 |
+
|
92 |
+
hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
|
93 |
+
hx4dup = _upsample_like(hx4d,hx3)
|
94 |
+
|
95 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
|
96 |
+
hx3dup = _upsample_like(hx3d,hx2)
|
97 |
+
|
98 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
|
99 |
+
hx2dup = _upsample_like(hx2d,hx1)
|
100 |
+
|
101 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
|
102 |
+
|
103 |
+
return hx1d + hxin
|
104 |
+
|
105 |
+
### RSU-6 ###
|
106 |
+
class RSU6(nn.Module):#UNet06DRES(nn.Module):
|
107 |
+
|
108 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
109 |
+
super(RSU6,self).__init__()
|
110 |
+
|
111 |
+
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
|
112 |
+
|
113 |
+
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
|
114 |
+
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
115 |
+
|
116 |
+
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
117 |
+
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
118 |
+
|
119 |
+
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
120 |
+
self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
121 |
+
|
122 |
+
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
123 |
+
self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
124 |
+
|
125 |
+
self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
126 |
+
|
127 |
+
self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
|
128 |
+
|
129 |
+
self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
130 |
+
self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
131 |
+
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
132 |
+
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
133 |
+
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
|
134 |
+
|
135 |
+
def forward(self,x):
|
136 |
+
|
137 |
+
hx = x
|
138 |
+
|
139 |
+
hxin = self.rebnconvin(hx)
|
140 |
+
|
141 |
+
hx1 = self.rebnconv1(hxin)
|
142 |
+
hx = self.pool1(hx1)
|
143 |
+
|
144 |
+
hx2 = self.rebnconv2(hx)
|
145 |
+
hx = self.pool2(hx2)
|
146 |
+
|
147 |
+
hx3 = self.rebnconv3(hx)
|
148 |
+
hx = self.pool3(hx3)
|
149 |
+
|
150 |
+
hx4 = self.rebnconv4(hx)
|
151 |
+
hx = self.pool4(hx4)
|
152 |
+
|
153 |
+
hx5 = self.rebnconv5(hx)
|
154 |
+
|
155 |
+
hx6 = self.rebnconv6(hx5)
|
156 |
+
|
157 |
+
|
158 |
+
hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
|
159 |
+
hx5dup = _upsample_like(hx5d,hx4)
|
160 |
+
|
161 |
+
hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
|
162 |
+
hx4dup = _upsample_like(hx4d,hx3)
|
163 |
+
|
164 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
|
165 |
+
hx3dup = _upsample_like(hx3d,hx2)
|
166 |
+
|
167 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
|
168 |
+
hx2dup = _upsample_like(hx2d,hx1)
|
169 |
+
|
170 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
|
171 |
+
|
172 |
+
return hx1d + hxin
|
173 |
+
|
174 |
+
### RSU-5 ###
|
175 |
+
class RSU5(nn.Module):#UNet05DRES(nn.Module):
|
176 |
+
|
177 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
178 |
+
super(RSU5,self).__init__()
|
179 |
+
|
180 |
+
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
|
181 |
+
|
182 |
+
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
|
183 |
+
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
184 |
+
|
185 |
+
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
186 |
+
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
187 |
+
|
188 |
+
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
189 |
+
self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
190 |
+
|
191 |
+
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
192 |
+
|
193 |
+
self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
|
194 |
+
|
195 |
+
self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
196 |
+
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
197 |
+
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
198 |
+
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
|
199 |
+
|
200 |
+
def forward(self,x):
|
201 |
+
|
202 |
+
hx = x
|
203 |
+
|
204 |
+
hxin = self.rebnconvin(hx)
|
205 |
+
|
206 |
+
hx1 = self.rebnconv1(hxin)
|
207 |
+
hx = self.pool1(hx1)
|
208 |
+
|
209 |
+
hx2 = self.rebnconv2(hx)
|
210 |
+
hx = self.pool2(hx2)
|
211 |
+
|
212 |
+
hx3 = self.rebnconv3(hx)
|
213 |
+
hx = self.pool3(hx3)
|
214 |
+
|
215 |
+
hx4 = self.rebnconv4(hx)
|
216 |
+
|
217 |
+
hx5 = self.rebnconv5(hx4)
|
218 |
+
|
219 |
+
hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
|
220 |
+
hx4dup = _upsample_like(hx4d,hx3)
|
221 |
+
|
222 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
|
223 |
+
hx3dup = _upsample_like(hx3d,hx2)
|
224 |
+
|
225 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
|
226 |
+
hx2dup = _upsample_like(hx2d,hx1)
|
227 |
+
|
228 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
|
229 |
+
|
230 |
+
return hx1d + hxin
|
231 |
+
|
232 |
+
### RSU-4 ###
|
233 |
+
class RSU4(nn.Module):#UNet04DRES(nn.Module):
|
234 |
+
|
235 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
236 |
+
super(RSU4,self).__init__()
|
237 |
+
|
238 |
+
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
|
239 |
+
|
240 |
+
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
|
241 |
+
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
242 |
+
|
243 |
+
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
244 |
+
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
245 |
+
|
246 |
+
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
247 |
+
|
248 |
+
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
|
249 |
+
|
250 |
+
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
251 |
+
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
252 |
+
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
|
253 |
+
|
254 |
+
def forward(self,x):
|
255 |
+
|
256 |
+
hx = x
|
257 |
+
|
258 |
+
hxin = self.rebnconvin(hx)
|
259 |
+
|
260 |
+
hx1 = self.rebnconv1(hxin)
|
261 |
+
hx = self.pool1(hx1)
|
262 |
+
|
263 |
+
hx2 = self.rebnconv2(hx)
|
264 |
+
hx = self.pool2(hx2)
|
265 |
+
|
266 |
+
hx3 = self.rebnconv3(hx)
|
267 |
+
|
268 |
+
hx4 = self.rebnconv4(hx3)
|
269 |
+
|
270 |
+
hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
|
271 |
+
hx3dup = _upsample_like(hx3d,hx2)
|
272 |
+
|
273 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
|
274 |
+
hx2dup = _upsample_like(hx2d,hx1)
|
275 |
+
|
276 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
|
277 |
+
|
278 |
+
return hx1d + hxin
|
279 |
+
|
280 |
+
### RSU-4F ###
|
281 |
+
class RSU4F(nn.Module):#UNet04FRES(nn.Module):
|
282 |
+
|
283 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
284 |
+
super(RSU4F,self).__init__()
|
285 |
+
|
286 |
+
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
|
287 |
+
|
288 |
+
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
|
289 |
+
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
|
290 |
+
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
|
291 |
+
|
292 |
+
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
|
293 |
+
|
294 |
+
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
|
295 |
+
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
|
296 |
+
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
|
297 |
+
|
298 |
+
def forward(self,x):
|
299 |
+
|
300 |
+
hx = x
|
301 |
+
|
302 |
+
hxin = self.rebnconvin(hx)
|
303 |
+
|
304 |
+
hx1 = self.rebnconv1(hxin)
|
305 |
+
hx2 = self.rebnconv2(hx1)
|
306 |
+
hx3 = self.rebnconv3(hx2)
|
307 |
+
|
308 |
+
hx4 = self.rebnconv4(hx3)
|
309 |
+
|
310 |
+
hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
|
311 |
+
hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
|
312 |
+
hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
|
313 |
+
|
314 |
+
return hx1d + hxin
|
315 |
+
|
316 |
+
|
317 |
+
##### U^2-Net ####
|
318 |
+
class U2NET(nn.Module):
|
319 |
+
|
320 |
+
def __init__(self,in_ch=3,out_ch=1):
|
321 |
+
super(U2NET,self).__init__()
|
322 |
+
|
323 |
+
self.stage1 = RSU7(in_ch,32,64)
|
324 |
+
self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
325 |
+
|
326 |
+
self.stage2 = RSU6(64,32,128)
|
327 |
+
self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
328 |
+
|
329 |
+
self.stage3 = RSU5(128,64,256)
|
330 |
+
self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
331 |
+
|
332 |
+
self.stage4 = RSU4(256,128,512)
|
333 |
+
self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
334 |
+
|
335 |
+
self.stage5 = RSU4F(512,256,512)
|
336 |
+
self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
337 |
+
|
338 |
+
self.stage6 = RSU4F(512,256,512)
|
339 |
+
|
340 |
+
# decoder
|
341 |
+
self.stage5d = RSU4F(1024,256,512)
|
342 |
+
self.stage4d = RSU4(1024,128,256)
|
343 |
+
self.stage3d = RSU5(512,64,128)
|
344 |
+
self.stage2d = RSU6(256,32,64)
|
345 |
+
self.stage1d = RSU7(128,16,64)
|
346 |
+
|
347 |
+
self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
|
348 |
+
self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
|
349 |
+
self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
|
350 |
+
self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
|
351 |
+
self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
|
352 |
+
self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
|
353 |
+
|
354 |
+
self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
|
355 |
+
|
356 |
+
def forward(self,x):
|
357 |
+
|
358 |
+
hx = x
|
359 |
+
|
360 |
+
#stage 1
|
361 |
+
hx1 = self.stage1(hx)
|
362 |
+
hx = self.pool12(hx1)
|
363 |
+
|
364 |
+
#stage 2
|
365 |
+
hx2 = self.stage2(hx)
|
366 |
+
hx = self.pool23(hx2)
|
367 |
+
|
368 |
+
#stage 3
|
369 |
+
hx3 = self.stage3(hx)
|
370 |
+
hx = self.pool34(hx3)
|
371 |
+
|
372 |
+
#stage 4
|
373 |
+
hx4 = self.stage4(hx)
|
374 |
+
hx = self.pool45(hx4)
|
375 |
+
|
376 |
+
#stage 5
|
377 |
+
hx5 = self.stage5(hx)
|
378 |
+
hx = self.pool56(hx5)
|
379 |
+
|
380 |
+
#stage 6
|
381 |
+
hx6 = self.stage6(hx)
|
382 |
+
hx6up = _upsample_like(hx6,hx5)
|
383 |
+
|
384 |
+
#-------------------- decoder --------------------
|
385 |
+
hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
|
386 |
+
hx5dup = _upsample_like(hx5d,hx4)
|
387 |
+
|
388 |
+
hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
|
389 |
+
hx4dup = _upsample_like(hx4d,hx3)
|
390 |
+
|
391 |
+
hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
|
392 |
+
hx3dup = _upsample_like(hx3d,hx2)
|
393 |
+
|
394 |
+
hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
|
395 |
+
hx2dup = _upsample_like(hx2d,hx1)
|
396 |
+
|
397 |
+
hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
|
398 |
+
|
399 |
+
|
400 |
+
#side output
|
401 |
+
d1 = self.side1(hx1d)
|
402 |
+
|
403 |
+
d2 = self.side2(hx2d)
|
404 |
+
d2 = _upsample_like(d2,d1)
|
405 |
+
|
406 |
+
d3 = self.side3(hx3d)
|
407 |
+
d3 = _upsample_like(d3,d1)
|
408 |
+
|
409 |
+
d4 = self.side4(hx4d)
|
410 |
+
d4 = _upsample_like(d4,d1)
|
411 |
+
|
412 |
+
d5 = self.side5(hx5d)
|
413 |
+
d5 = _upsample_like(d5,d1)
|
414 |
+
|
415 |
+
d6 = self.side6(hx6)
|
416 |
+
d6 = _upsample_like(d6,d1)
|
417 |
+
|
418 |
+
d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
|
419 |
+
|
420 |
+
return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)
|
421 |
+
|
422 |
+
### U^2-Net small ###
|
423 |
+
class U2NETP(nn.Module):
|
424 |
+
|
425 |
+
def __init__(self,in_ch=3,out_ch=1):
|
426 |
+
super(U2NETP,self).__init__()
|
427 |
+
|
428 |
+
self.stage1 = RSU7(in_ch,16,64)
|
429 |
+
self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
430 |
+
|
431 |
+
self.stage2 = RSU6(64,16,64)
|
432 |
+
self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
433 |
+
|
434 |
+
self.stage3 = RSU5(64,16,64)
|
435 |
+
self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
436 |
+
|
437 |
+
self.stage4 = RSU4(64,16,64)
|
438 |
+
self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
439 |
+
|
440 |
+
self.stage5 = RSU4F(64,16,64)
|
441 |
+
self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
442 |
+
|
443 |
+
self.stage6 = RSU4F(64,16,64)
|
444 |
+
|
445 |
+
# decoder
|
446 |
+
self.stage5d = RSU4F(128,16,64)
|
447 |
+
self.stage4d = RSU4(128,16,64)
|
448 |
+
self.stage3d = RSU5(128,16,64)
|
449 |
+
self.stage2d = RSU6(128,16,64)
|
450 |
+
self.stage1d = RSU7(128,16,64)
|
451 |
+
|
452 |
+
self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
|
453 |
+
self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
|
454 |
+
self.side3 = nn.Conv2d(64,out_ch,3,padding=1)
|
455 |
+
self.side4 = nn.Conv2d(64,out_ch,3,padding=1)
|
456 |
+
self.side5 = nn.Conv2d(64,out_ch,3,padding=1)
|
457 |
+
self.side6 = nn.Conv2d(64,out_ch,3,padding=1)
|
458 |
+
|
459 |
+
self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
|
460 |
+
|
461 |
+
def forward(self,x):
|
462 |
+
|
463 |
+
hx = x
|
464 |
+
|
465 |
+
#stage 1
|
466 |
+
hx1 = self.stage1(hx)
|
467 |
+
hx = self.pool12(hx1)
|
468 |
+
|
469 |
+
#stage 2
|
470 |
+
hx2 = self.stage2(hx)
|
471 |
+
hx = self.pool23(hx2)
|
472 |
+
|
473 |
+
#stage 3
|
474 |
+
hx3 = self.stage3(hx)
|
475 |
+
hx = self.pool34(hx3)
|
476 |
+
|
477 |
+
#stage 4
|
478 |
+
hx4 = self.stage4(hx)
|
479 |
+
hx = self.pool45(hx4)
|
480 |
+
|
481 |
+
#stage 5
|
482 |
+
hx5 = self.stage5(hx)
|
483 |
+
hx = self.pool56(hx5)
|
484 |
+
|
485 |
+
#stage 6
|
486 |
+
hx6 = self.stage6(hx)
|
487 |
+
hx6up = _upsample_like(hx6,hx5)
|
488 |
+
|
489 |
+
#decoder
|
490 |
+
hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
|
491 |
+
hx5dup = _upsample_like(hx5d,hx4)
|
492 |
+
|
493 |
+
hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
|
494 |
+
hx4dup = _upsample_like(hx4d,hx3)
|
495 |
+
|
496 |
+
hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
|
497 |
+
hx3dup = _upsample_like(hx3d,hx2)
|
498 |
+
|
499 |
+
hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
|
500 |
+
hx2dup = _upsample_like(hx2d,hx1)
|
501 |
+
|
502 |
+
hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
|
503 |
+
|
504 |
+
|
505 |
+
#side output
|
506 |
+
d1 = self.side1(hx1d)
|
507 |
+
|
508 |
+
d2 = self.side2(hx2d)
|
509 |
+
d2 = _upsample_like(d2,d1)
|
510 |
+
|
511 |
+
d3 = self.side3(hx3d)
|
512 |
+
d3 = _upsample_like(d3,d1)
|
513 |
+
|
514 |
+
d4 = self.side4(hx4d)
|
515 |
+
d4 = _upsample_like(d4,d1)
|
516 |
+
|
517 |
+
d5 = self.side5(hx5d)
|
518 |
+
d5 = _upsample_like(d5,d1)
|
519 |
+
|
520 |
+
d6 = self.side6(hx6)
|
521 |
+
d6 = _upsample_like(d6,d1)
|
522 |
+
|
523 |
+
d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
|
524 |
+
|
525 |
+
return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)
|
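The U2NET class above returns seven sigmoid maps: the fused output d0 followed by the six side outputs, all upsampled to the input resolution. A minimal smoke-test sketch (not part of this commit; it assumes PyTorch is installed and the repository root is importable):

import torch
from U2Net_.model.u2net import U2NET  # the class defined above

net = U2NET(in_ch=3, out_ch=1)
net.eval()
with torch.no_grad():
    x = torch.randn(1, 3, 320, 320)        # dummy RGB batch
    d0, d1, d2, d3, d4, d5, d6 = net(x)    # fused map + six side maps
print(d0.shape)  # torch.Size([1, 1, 320, 320]), values in (0, 1) after the sigmoid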
U2Net_/model/u2net_refactor.py
ADDED
@@ -0,0 +1,168 @@
import torch
import torch.nn as nn

import math

__all__ = ['U2NET_full', 'U2NET_lite']


def _upsample_like(x, size):
    return nn.Upsample(size=size, mode='bilinear', align_corners=False)(x)


def _size_map(x, height):
    # {height: size} for Upsample
    size = list(x.shape[-2:])
    sizes = {}
    for h in range(1, height):
        sizes[h] = size
        size = [math.ceil(w / 2) for w in size]
    return sizes


class REBNCONV(nn.Module):
    def __init__(self, in_ch=3, out_ch=3, dilate=1):
        super(REBNCONV, self).__init__()

        self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dilate, dilation=1 * dilate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu_s1(self.bn_s1(self.conv_s1(x)))


class RSU(nn.Module):
    def __init__(self, name, height, in_ch, mid_ch, out_ch, dilated=False):
        super(RSU, self).__init__()
        self.name = name
        self.height = height
        self.dilated = dilated
        self._make_layers(height, in_ch, mid_ch, out_ch, dilated)

    def forward(self, x):
        sizes = _size_map(x, self.height)
        x = self.rebnconvin(x)

        # U-Net like symmetric encoder-decoder structure
        def unet(x, height=1):
            if height < self.height:
                x1 = getattr(self, f'rebnconv{height}')(x)
                if not self.dilated and height < self.height - 1:
                    x2 = unet(getattr(self, 'downsample')(x1), height + 1)
                else:
                    x2 = unet(x1, height + 1)

                x = getattr(self, f'rebnconv{height}d')(torch.cat((x2, x1), 1))
                return _upsample_like(x, sizes[height - 1]) if not self.dilated and height > 1 else x
            else:
                return getattr(self, f'rebnconv{height}')(x)

        return x + unet(x)

    def _make_layers(self, height, in_ch, mid_ch, out_ch, dilated=False):
        self.add_module('rebnconvin', REBNCONV(in_ch, out_ch))
        self.add_module('downsample', nn.MaxPool2d(2, stride=2, ceil_mode=True))

        self.add_module(f'rebnconv1', REBNCONV(out_ch, mid_ch))
        self.add_module(f'rebnconv1d', REBNCONV(mid_ch * 2, out_ch))

        for i in range(2, height):
            dilate = 1 if not dilated else 2 ** (i - 1)
            self.add_module(f'rebnconv{i}', REBNCONV(mid_ch, mid_ch, dilate=dilate))
            self.add_module(f'rebnconv{i}d', REBNCONV(mid_ch * 2, mid_ch, dilate=dilate))

        dilate = 2 if not dilated else 2 ** (height - 1)
        self.add_module(f'rebnconv{height}', REBNCONV(mid_ch, mid_ch, dilate=dilate))


class U2NET(nn.Module):
    def __init__(self, cfgs, out_ch):
        super(U2NET, self).__init__()
        self.out_ch = out_ch
        self._make_layers(cfgs)

    def forward(self, x):
        sizes = _size_map(x, self.height)
        maps = []  # storage for maps

        # side saliency map
        def unet(x, height=1):
            if height < 6:
                x1 = getattr(self, f'stage{height}')(x)
                x2 = unet(getattr(self, 'downsample')(x1), height + 1)
                x = getattr(self, f'stage{height}d')(torch.cat((x2, x1), 1))
                side(x, height)
                return _upsample_like(x, sizes[height - 1]) if height > 1 else x
            else:
                x = getattr(self, f'stage{height}')(x)
                side(x, height)
                return _upsample_like(x, sizes[height - 1])

        def side(x, h):
            # side output saliency map (before sigmoid)
            x = getattr(self, f'side{h}')(x)
            x = _upsample_like(x, sizes[1])
            maps.append(x)

        def fuse():
            # fuse saliency probability maps
            maps.reverse()
            x = torch.cat(maps, 1)
            x = getattr(self, 'outconv')(x)
            maps.insert(0, x)
            return [torch.sigmoid(x) for x in maps]

        unet(x)
        maps = fuse()
        return maps

    def _make_layers(self, cfgs):
        self.height = int((len(cfgs) + 1) / 2)
        self.add_module('downsample', nn.MaxPool2d(2, stride=2, ceil_mode=True))
        for k, v in cfgs.items():
            # build rsu block
            self.add_module(k, RSU(v[0], *v[1]))
            if v[2] > 0:
                # build side layer
                self.add_module(f'side{v[0][-1]}', nn.Conv2d(v[2], self.out_ch, 3, padding=1))
        # build fuse layer
        self.add_module('outconv', nn.Conv2d(int(self.height * self.out_ch), self.out_ch, 1))


def U2NET_full():
    full = {
        # cfgs for building RSUs and sides
        # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
        'stage1': ['En_1', (7, 3, 32, 64), -1],
        'stage2': ['En_2', (6, 64, 32, 128), -1],
        'stage3': ['En_3', (5, 128, 64, 256), -1],
        'stage4': ['En_4', (4, 256, 128, 512), -1],
        'stage5': ['En_5', (4, 512, 256, 512, True), -1],
        'stage6': ['En_6', (4, 512, 256, 512, True), 512],
        'stage5d': ['De_5', (4, 1024, 256, 512, True), 512],
        'stage4d': ['De_4', (4, 1024, 128, 256), 256],
        'stage3d': ['De_3', (5, 512, 64, 128), 128],
        'stage2d': ['De_2', (6, 256, 32, 64), 64],
        'stage1d': ['De_1', (7, 128, 16, 64), 64],
    }
    return U2NET(cfgs=full, out_ch=1)


def U2NET_lite():
    lite = {
        # cfgs for building RSUs and sides
        # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
        'stage1': ['En_1', (7, 3, 16, 64), -1],
        'stage2': ['En_2', (6, 64, 16, 64), -1],
        'stage3': ['En_3', (5, 64, 16, 64), -1],
        'stage4': ['En_4', (4, 64, 16, 64), -1],
        'stage5': ['En_5', (4, 64, 16, 64, True), -1],
        'stage6': ['En_6', (4, 64, 16, 64, True), 64],
        'stage5d': ['De_5', (4, 128, 16, 64, True), 64],
        'stage4d': ['De_4', (4, 128, 16, 64), 64],
        'stage3d': ['De_3', (5, 128, 16, 64), 64],
        'stage2d': ['De_2', (6, 128, 16, 64), 64],
        'stage1d': ['De_1', (7, 128, 16, 64), 64],
    }
    return U2NET(cfgs=lite, out_ch=1)
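In the refactored version above, each cfgs entry maps directly onto the RSU constructor, e.g. 'stage1': ['En_1', (7, 3, 32, 64), -1] builds RSU('En_1', 7, 3, 32, 64). A small sketch (not part of this commit) that exercises one such block on its own:

import torch
from U2Net_.model.u2net_refactor import RSU

# Equivalent to 'stage1' of U2NET_full; RSU blocks preserve the input's spatial size.
block = RSU('En_1', 7, 3, 32, 64, dilated=False)
y = block(torch.randn(1, 3, 128, 128))
print(y.shape)  # torch.Size([1, 64, 128, 128])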
U2Net_/saved_models/face_detection_cv2/haarcascade_frontalface_default.xml
ADDED
The diff for this file is too large to render.
See raw diff
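The XML added here is the stock OpenCV frontal-face Haar cascade. How this repository loads it is not part of the rendered diff; the following is only a generic OpenCV usage sketch with illustrative paths and parameters:

import cv2

cascade = cv2.CascadeClassifier(
    "U2Net_/saved_models/face_detection_cv2/haarcascade_frontalface_default.xml")
img = cv2.imread("target_images/camel.png")          # any image from target_images/
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
faces = cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5)
print(faces)  # (x, y, w, h) boxes, empty when no frontal face is detected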
cog.yaml
ADDED
@@ -0,0 +1,54 @@
# Configuration for Cog ⚙️
# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md

build:
  # set to true if your model requires a GPU
  gpu: true
  cuda: "10.1"

  # a list of ubuntu apt packages to install
  system_packages:
    #- "libopenmpi-dev"
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
    # - "cmake-"

  # python version in the form '3.8' or '3.8.12'
  python_version: "3.7.9"

  # a list of packages in the format <package-name>==<version>
  python_packages:
    # cmake==3.21.2
    # - "pip==21.2.2"
    - "cmake==3.14.3"
    - "torch==1.7.1"
    - "torchvision==0.8.2"
    - "numpy==1.19.2"
    - "ipython==7.21.0"
    - "Pillow==8.3.1"
    - "svgwrite==1.4.1"
    - "svgpathtools==1.4.1"
    - "cssutils==2.3.0"
    - "numba==0.55.1"
    - "torch-tools==0.1.5"
    - "visdom==0.1.8.9"
    - "ftfy==6.1.1"
    - "regex==2021.8.28"
    - "tqdm==4.62.3"
    - "scikit-image==0.18.3"
    - "gdown==4.4.0"
    - "wandb==0.12.0"
    - "tensorflow-gpu==1.15.2"

  # commands run after the environment is setup
  run:
    # - /root/.pyenv/versions/3.7.9/bin/python3.7 -m pip install --upgrade pip
    - export PYTHONPATH="/diffvg/build/lib.linux-x86_64-3.7"
    - git clone https://github.com/BachiLi/diffvg && cd diffvg && git submodule update --init --recursive && CMAKE_PREFIX_PATH=$(pyenv prefix) DIFFVG_CUDA=1 python setup.py install
    - pip install git+https://github.com/openai/CLIP.git --no-deps
    - gdown "https://drive.google.com/uc?id=1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ" -O "/src/U2Net_/saved_models/"
    - "echo env is ready!"
    - "echo another command if needed"

# predict.py defines how predictions are run on your model
predict: "predict.py:Predictor"
config.py
ADDED
@@ -0,0 +1,144 @@
import argparse
import os
import random

import numpy as np
import pydiffvg
import torch
import wandb


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def parse_arguments():
    parser = argparse.ArgumentParser()
    # =================================
    # ============ general ============
    # =================================
    parser.add_argument("target", help="target image path")
    parser.add_argument("--output_dir", type=str,
                        help="directory to save the output images and loss")
    parser.add_argument("--path_svg", type=str, default="none",
                        help="if you want to load an svg file and train from it")
    parser.add_argument("--use_gpu", type=int, default=0)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--mask_object", type=int, default=0)
    parser.add_argument("--fix_scale", type=int, default=0)
    parser.add_argument("--display_logs", type=int, default=0)
    parser.add_argument("--display", type=int, default=0)

    # =================================
    # ============ wandb ==============
    # =================================
    parser.add_argument("--use_wandb", type=int, default=0)
    parser.add_argument("--wandb_user", type=str, default="yael-vinker")
    parser.add_argument("--wandb_name", type=str, default="test")
    parser.add_argument("--wandb_project_name", type=str, default="none")

    # =================================
    # =========== training ============
    # =================================
    parser.add_argument("--num_iter", type=int, default=500,
                        help="number of optimization iterations")
    parser.add_argument("--num_stages", type=int, default=1,
                        help="training stages, you can train x strokes, then freeze them and train another x strokes etc.")
    parser.add_argument("--lr_scheduler", type=int, default=0)
    parser.add_argument("--lr", type=float, default=1.0)
    parser.add_argument("--color_lr", type=float, default=0.01)
    parser.add_argument("--color_vars_threshold", type=float, default=0.0)
    parser.add_argument("--batch_size", type=int, default=1,
                        help="for optimization it's only one image")
    parser.add_argument("--save_interval", type=int, default=10)
    parser.add_argument("--eval_interval", type=int, default=10)
    parser.add_argument("--image_scale", type=int, default=224)

    # =================================
    # ======== strokes params =========
    # =================================
    parser.add_argument("--num_paths", type=int,
                        default=16, help="number of strokes")
    parser.add_argument("--width", type=float,
                        default=1.5, help="stroke width")
    parser.add_argument("--control_points_per_seg", type=int, default=4)
    parser.add_argument("--num_segments", type=int, default=1,
                        help="number of segments for each stroke, each stroke is a bezier curve with 4 control points")
    parser.add_argument("--attention_init", type=int, default=1,
                        help="if True, use the attention heads of Dino model to set the location of the initial strokes")
    parser.add_argument("--saliency_model", type=str, default="clip")
    parser.add_argument("--saliency_clip_model", type=str, default="ViT-B/32")
    parser.add_argument("--xdog_intersec", type=int, default=1)
    parser.add_argument("--mask_object_attention", type=int, default=0)
    parser.add_argument("--softmax_temp", type=float, default=0.3)

    # =================================
    # ============= loss ==============
    # =================================
    parser.add_argument("--percep_loss", type=str, default="none",
                        help="the type of perceptual loss to be used (L2/LPIPS/none)")
    parser.add_argument("--perceptual_weight", type=float, default=0,
                        help="weight the perceptual loss")
    parser.add_argument("--train_with_clip", type=int, default=0)
    parser.add_argument("--clip_weight", type=float, default=0)
    parser.add_argument("--start_clip", type=int, default=0)
    parser.add_argument("--num_aug_clip", type=int, default=4)
    parser.add_argument("--include_target_in_aug", type=int, default=0)
    parser.add_argument("--augment_both", type=int, default=1,
                        help="if you want to apply the affine augmentation to both the sketch and image")
    parser.add_argument("--augemntations", type=str, default="affine",
                        help="can be any combination of: 'affine_noise_eraserchunks_eraser_press'")
    parser.add_argument("--noise_thresh", type=float, default=0.5)
    parser.add_argument("--aug_scale_min", type=float, default=0.7)
    parser.add_argument("--force_sparse", type=float, default=0,
                        help="if True, use L1 regularization on stroke's opacity to encourage small number of strokes")
    parser.add_argument("--clip_conv_loss", type=float, default=1)
    parser.add_argument("--clip_conv_loss_type", type=str, default="L2")
    parser.add_argument("--clip_conv_layer_weights",
                        type=str, default="0,0,1.0,1.0,0")
    parser.add_argument("--clip_model_name", type=str, default="RN101")
    parser.add_argument("--clip_fc_loss_weight", type=float, default=0.1)
    parser.add_argument("--clip_text_guide", type=float, default=0)
    parser.add_argument("--text_target", type=str, default="none")

    args = parser.parse_args()
    set_seed(args.seed)

    args.clip_conv_layer_weights = [
        float(item) for item in args.clip_conv_layer_weights.split(',')]

    args.output_dir = os.path.join(args.output_dir, args.wandb_name)
    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)

    jpg_logs_dir = f"{args.output_dir}/jpg_logs"
    svg_logs_dir = f"{args.output_dir}/svg_logs"
    if not os.path.exists(jpg_logs_dir):
        os.mkdir(jpg_logs_dir)
    if not os.path.exists(svg_logs_dir):
        os.mkdir(svg_logs_dir)

    if args.use_wandb:
        wandb.init(project=args.wandb_project_name, entity=args.wandb_user,
                   config=args, name=args.wandb_name, id=wandb.util.generate_id())

    if args.use_gpu:
        args.device = torch.device("cuda" if (
            torch.cuda.is_available() and torch.cuda.device_count() > 0) else "cpu")
    else:
        args.device = torch.device("cpu")
    pydiffvg.set_use_gpu(torch.cuda.is_available() and args.use_gpu)
    pydiffvg.set_device(args.device)
    return args


if __name__ == "__main__":
    # for cog predict
    args = parse_arguments()
    final_config = vars(args)
    np.save(f"{args.output_dir}/config_init.npy", final_config)
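config.py both seeds all random number generators and dumps the parsed arguments next to the outputs. A minimal sketch of reading that dump back, assuming the run used --output_dir output_sketches and the default --wandb_name "test" (adjust the path to wherever the run actually wrote its outputs):

import numpy as np

# config_init.npy is written by the __main__ block above as a pickled dict of vars(args)
cfg = np.load("output_sketches/test/config_init.npy", allow_pickle=True)[()]
print(cfg["num_paths"], cfg["clip_conv_layer_weights"])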
display_results.py
ADDED
@@ -0,0 +1,94 @@
import argparse
import os
import re

import imageio
import matplotlib.pyplot as plt
import moviepy.editor as mvp
import numpy as np
import pydiffvg
import torch
from IPython.display import Image as Image_colab
from IPython.display import display, SVG
from PIL import Image

parser = argparse.ArgumentParser()
parser.add_argument("--target_file", type=str,
                    help="target image file, located in <target_images>")
parser.add_argument("--num_strokes", type=int)
args = parser.parse_args()


def read_svg(path_svg, multiply=False):
    device = torch.device("cuda" if (
        torch.cuda.is_available() and torch.cuda.device_count() > 0) else "cpu")
    canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(
        path_svg)
    if multiply:
        canvas_width *= 2
        canvas_height *= 2
        for path in shapes:
            path.points *= 2
            path.stroke_width *= 2
    _render = pydiffvg.RenderFunction.apply
    scene_args = pydiffvg.RenderFunction.serialize_scene(
        canvas_width, canvas_height, shapes, shape_groups)
    img = _render(canvas_width,  # width
                  canvas_height,  # height
                  2,  # num_samples_x
                  2,  # num_samples_y
                  0,  # seed
                  None,
                  *scene_args)
    img = img[:, :, 3:4] * img[:, :, :3] + \
        torch.ones(img.shape[0], img.shape[1], 3,
                   device=device) * (1 - img[:, :, 3:4])
    img = img[:, :, :3]
    return img


abs_path = os.path.abspath(os.getcwd())

result_path = f"{abs_path}/output_sketches/{os.path.splitext(args.target_file)[0]}"
svg_files = os.listdir(result_path)
svg_files = [f for f in svg_files if "best.svg" in f and f"{args.num_strokes}strokes" in f]
svg_output_path = f"{result_path}/{svg_files[0]}"

target_path = f"{svg_output_path[:-9]}/input.png"

sketch_res = read_svg(svg_output_path, multiply=True).cpu().numpy()
sketch_res = Image.fromarray((sketch_res * 255).astype('uint8'), 'RGB')

input_im = Image.open(target_path).resize((224, 224))
display(input_im)
display(SVG(svg_output_path))

p = re.compile("_best")
best_sketch_dir = ""
for m in p.finditer(svg_files[0]):
    best_sketch_dir += svg_files[0][0: m.start()]


sketches = []
cur_path = f"{result_path}/{best_sketch_dir}"
sketch_res.save(f"{cur_path}/final_sketch.png")
print(f"You can download the result sketch from {cur_path}/final_sketch.png")

if not os.path.exists(f"{cur_path}/svg_to_png"):
    os.mkdir(f"{cur_path}/svg_to_png")
if os.path.exists(f"{cur_path}/config.npy"):
    config = np.load(f"{cur_path}/config.npy", allow_pickle=True)[()]
    inter = config["save_interval"]
    loss_eval = np.array(config['loss_eval'])
    inds = np.argsort(loss_eval)
    intervals = list(range(0, (inds[0] + 1) * inter, inter))
    for i_ in intervals:
        path_svg = f"{cur_path}/svg_logs/svg_iter{i_}.svg"
        sketch = read_svg(path_svg, multiply=True).cpu().numpy()
        sketch = Image.fromarray((sketch * 255).astype('uint8'), 'RGB')
        # print("{0}/iter_{1:04}.png".format(cur_path, int(i_)))
        sketch.save("{0}/{1}/iter_{2:04}.png".format(cur_path, "svg_to_png", int(i_)))
        sketches.append(sketch)
    imageio.mimsave(f"{cur_path}/sketch.gif", sketches)

print(cur_path)
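The gif assembly above relies on a small piece of arithmetic: frames are taken every save_interval iterations up to the best-scoring evaluation. A toy check of that logic, with made-up loss values and an assumed save_interval of 10:

import numpy as np

# hypothetical loss curve: the 4th logged evaluation (index 3) scores best
loss_eval = np.array([0.9, 0.7, 0.6, 0.4, 0.5])
inter = 10                                # assumed save_interval
best = int(np.argsort(loss_eval)[0])      # -> 3
intervals = list(range(0, (best + 1) * inter, inter))
print(intervals)                          # [0, 10, 20, 30] -> svg_iter0.svg ... svg_iter30.svg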
models/loss.py
ADDED
@@ -0,0 +1,463 @@
import collections
import CLIP_.clip as clip
import torch
import torch.nn as nn
from torchvision import models, transforms


class Loss(nn.Module):
    def __init__(self, args):
        super(Loss, self).__init__()
        self.args = args
        self.percep_loss = args.percep_loss

        self.train_with_clip = args.train_with_clip
        self.clip_weight = args.clip_weight
        self.start_clip = args.start_clip

        self.clip_conv_loss = args.clip_conv_loss
        self.clip_fc_loss_weight = args.clip_fc_loss_weight
        self.clip_text_guide = args.clip_text_guide

        self.losses_to_apply = self.get_losses_to_apply()

        self.loss_mapper = \
            {
                "clip": CLIPLoss(args),
                "clip_conv_loss": CLIPConvLoss(args)
            }

    def get_losses_to_apply(self):
        losses_to_apply = []
        if self.percep_loss != "none":
            losses_to_apply.append(self.percep_loss)
        if self.train_with_clip and self.start_clip == 0:
            losses_to_apply.append("clip")
        if self.clip_conv_loss:
            losses_to_apply.append("clip_conv_loss")
        if self.clip_text_guide:
            losses_to_apply.append("clip_text")
        return losses_to_apply

    def update_losses_to_apply(self, epoch):
        if "clip" not in self.losses_to_apply:
            if self.train_with_clip:
                if epoch > self.start_clip:
                    self.losses_to_apply.append("clip")

    def forward(self, sketches, targets, color_parameters, renderer, epoch, points_optim=None, mode="train"):
        loss = 0
        self.update_losses_to_apply(epoch)

        losses_dict = dict.fromkeys(
            self.losses_to_apply, torch.tensor([0.0]).to(self.args.device))
        loss_coeffs = dict.fromkeys(self.losses_to_apply, 1.0)
        loss_coeffs["clip"] = self.clip_weight
        loss_coeffs["clip_text"] = self.clip_text_guide

        for loss_name in self.losses_to_apply:
            if loss_name in ["clip_conv_loss"]:
                conv_loss = self.loss_mapper[loss_name](
                    sketches, targets, mode)
                for layer in conv_loss.keys():
                    losses_dict[layer] = conv_loss[layer]
            elif loss_name == "l2":
                losses_dict[loss_name] = self.loss_mapper[loss_name](
                    sketches, targets).mean()
            else:
                losses_dict[loss_name] = self.loss_mapper[loss_name](
                    sketches, targets, mode).mean()
            # loss = loss + self.loss_mapper[loss_name](sketches, targets).mean() * loss_coeffs[loss_name]

        for key in self.losses_to_apply:
            # loss = loss + losses_dict[key] * loss_coeffs[key]
            losses_dict[key] = losses_dict[key] * loss_coeffs[key]
        # print(losses_dict)
        return losses_dict


class CLIPLoss(torch.nn.Module):
    def __init__(self, args):
        super(CLIPLoss, self).__init__()

        self.args = args
        self.model, clip_preprocess = clip.load(
            'ViT-B/32', args.device, jit=False)
        self.model.eval()
        self.preprocess = transforms.Compose(
            [clip_preprocess.transforms[-1]])  # clip normalisation
        self.device = args.device
        self.NUM_AUGS = args.num_aug_clip
        augemntations = []
        if "affine" in args.augemntations:
            augemntations.append(transforms.RandomPerspective(
                fill=0, p=1.0, distortion_scale=0.5))
            augemntations.append(transforms.RandomResizedCrop(
                224, scale=(0.8, 0.8), ratio=(1.0, 1.0)))
        augemntations.append(
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)))
        self.augment_trans = transforms.Compose(augemntations)

        self.calc_target = True
        self.include_target_in_aug = args.include_target_in_aug
        self.counter = 0
        self.augment_both = args.augment_both

    def forward(self, sketches, targets, mode="train"):
        if self.calc_target:
            targets_ = self.preprocess(targets).to(self.device)
            self.targets_features = self.model.encode_image(targets_).detach()
            self.calc_target = False

        if mode == "eval":
            # for regular clip distance, no augmentations
            with torch.no_grad():
                sketches = self.preprocess(sketches).to(self.device)
                sketches_features = self.model.encode_image(sketches)
                return 1. - torch.cosine_similarity(sketches_features, self.targets_features)

        loss_clip = 0
        sketch_augs = []
        img_augs = []
        for n in range(self.NUM_AUGS):
            augmented_pair = self.augment_trans(torch.cat([sketches, targets]))
            sketch_augs.append(augmented_pair[0].unsqueeze(0))

        sketch_batch = torch.cat(sketch_augs)
        # sketch_utils.plot_batch(img_batch, sketch_batch, self.args, self.counter, use_wandb=False, title="fc_aug{}_iter{}_{}.jpg".format(1, self.counter, mode))
        # if self.counter % 100 == 0:
        #     sketch_utils.plot_batch(img_batch, sketch_batch, self.args, self.counter, use_wandb=False, title="aug{}_iter{}_{}.jpg".format(1, self.counter, mode))

        sketch_features = self.model.encode_image(sketch_batch)

        for n in range(self.NUM_AUGS):
            loss_clip += (1. - torch.cosine_similarity(
                sketch_features[n:n+1], self.targets_features, dim=1))
        self.counter += 1
        return loss_clip
        # return 1. - torch.cosine_similarity(sketches_features, self.targets_features)


class LPIPS(torch.nn.Module):
    def __init__(self, pretrained=True, normalize=True, pre_relu=True, device=None):
        """
        Args:
            pre_relu(bool): if True, selects features **before** reLU activations
        """
        super(LPIPS, self).__init__()
        # VGG using perceptually-learned weights (LPIPS metric)
        self.normalize = normalize
        self.pretrained = pretrained
        augemntations = []
        augemntations.append(transforms.RandomPerspective(
            fill=0, p=1.0, distortion_scale=0.5))
        augemntations.append(transforms.RandomResizedCrop(
            224, scale=(0.8, 0.8), ratio=(1.0, 1.0)))
        self.augment_trans = transforms.Compose(augemntations)
        self.feature_extractor = LPIPS._FeatureExtractor(
            pretrained, pre_relu).to(device)

    def _l2_normalize_features(self, x, eps=1e-10):
        nrm = torch.sqrt(torch.sum(x * x, dim=1, keepdim=True))
        return x / (nrm + eps)

    def forward(self, pred, target, mode="train"):
        """Compare VGG features of two inputs."""

        # Get VGG features

        sketch_augs, img_augs = [pred], [target]
        if mode == "train":
            for n in range(4):
                augmented_pair = self.augment_trans(torch.cat([pred, target]))
                sketch_augs.append(augmented_pair[0].unsqueeze(0))
                img_augs.append(augmented_pair[1].unsqueeze(0))

        xs = torch.cat(sketch_augs, dim=0)
        ys = torch.cat(img_augs, dim=0)

        pred = self.feature_extractor(xs)
        target = self.feature_extractor(ys)

        # L2 normalize features
        if self.normalize:
            pred = [self._l2_normalize_features(f) for f in pred]
            target = [self._l2_normalize_features(f) for f in target]

        # TODO(mgharbi) Apply Richard's linear weights?

        if self.normalize:
            diffs = [torch.sum((p - t) ** 2, 1)
                     for (p, t) in zip(pred, target)]
        else:
            # mean instead of sum to avoid super high range
            diffs = [torch.mean((p - t) ** 2, 1)
                     for (p, t) in zip(pred, target)]

        # Spatial average
        diffs = [diff.mean([1, 2]) for diff in diffs]

        return sum(diffs)

    class _FeatureExtractor(torch.nn.Module):
        def __init__(self, pretrained, pre_relu):
            super(LPIPS._FeatureExtractor, self).__init__()
            vgg_pretrained = models.vgg16(pretrained=pretrained).features

            self.breakpoints = [0, 4, 9, 16, 23, 30]
            if pre_relu:
                for i, _ in enumerate(self.breakpoints[1:]):
                    self.breakpoints[i + 1] -= 1

            # Split at the maxpools
            for i, b in enumerate(self.breakpoints[:-1]):
                ops = torch.nn.Sequential()
                for idx in range(b, self.breakpoints[i + 1]):
                    op = vgg_pretrained[idx]
                    ops.add_module(str(idx), op)
                # print(ops)
                self.add_module("group{}".format(i), ops)

            # No gradients
            for p in self.parameters():
                p.requires_grad = False

            # Torchvision's normalization: <https://github.com/pytorch/examples/blob/42e5b996718797e45c46a25c55b031e6768f8440/imagenet/main.py#L89-L101>
            self.register_buffer("shift", torch.Tensor(
                [0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            self.register_buffer("scale", torch.Tensor(
                [0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        def forward(self, x):
            feats = []
            x = (x - self.shift) / self.scale
            for idx in range(len(self.breakpoints) - 1):
                m = getattr(self, "group{}".format(idx))
                x = m(x)
                feats.append(x)
            return feats


class L2_(torch.nn.Module):
    def __init__(self):
        """
        Args:
            pre_relu(bool): if True, selects features **before** reLU activations
        """
        super(L2_, self).__init__()
        # VGG using perceptually-learned weights (LPIPS metric)
        augemntations = []
        augemntations.append(transforms.RandomPerspective(
            fill=0, p=1.0, distortion_scale=0.5))
        augemntations.append(transforms.RandomResizedCrop(
            224, scale=(0.8, 0.8), ratio=(1.0, 1.0)))
        augemntations.append(
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)))
        self.augment_trans = transforms.Compose(augemntations)
        # LOG.warning("LPIPS is untested")

    def forward(self, pred, target, mode="train"):
        """Compare VGG features of two inputs."""

        # Get VGG features

        sketch_augs, img_augs = [pred], [target]
        if mode == "train":
            for n in range(4):
                augmented_pair = self.augment_trans(torch.cat([pred, target]))
                sketch_augs.append(augmented_pair[0].unsqueeze(0))
                img_augs.append(augmented_pair[1].unsqueeze(0))

        pred = torch.cat(sketch_augs, dim=0)
        target = torch.cat(img_augs, dim=0)
        diffs = [torch.square(p - t).mean() for (p, t) in zip(pred, target)]
        return sum(diffs)


class CLIPVisualEncoder(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.clip_model = clip_model
        self.featuremaps = None

        for i in range(12):  # 12 resblocks in VIT visual transformer
            self.clip_model.visual.transformer.resblocks[i].register_forward_hook(
                self.make_hook(i))

    def make_hook(self, name):
        def hook(module, input, output):
            if len(output.shape) == 3:
                self.featuremaps[name] = output.permute(
                    1, 0, 2)  # LND -> NLD bs, smth, 768
            else:
                self.featuremaps[name] = output

        return hook

    def forward(self, x):
        self.featuremaps = collections.OrderedDict()
        fc_features = self.clip_model.encode_image(x).float()
        featuremaps = [self.featuremaps[k] for k in range(12)]

        return fc_features, featuremaps


def l2_layers(xs_conv_features, ys_conv_features, clip_model_name):
    return [torch.square(x_conv - y_conv).mean() for x_conv, y_conv in
            zip(xs_conv_features, ys_conv_features)]


def l1_layers(xs_conv_features, ys_conv_features, clip_model_name):
    return [torch.abs(x_conv - y_conv).mean() for x_conv, y_conv in
            zip(xs_conv_features, ys_conv_features)]


def cos_layers(xs_conv_features, ys_conv_features, clip_model_name):
    if "RN" in clip_model_name:
        return [torch.square(x_conv, y_conv, dim=1).mean() for x_conv, y_conv in
                zip(xs_conv_features, ys_conv_features)]
    return [(1 - torch.cosine_similarity(x_conv, y_conv, dim=1)).mean() for x_conv, y_conv in
            zip(xs_conv_features, ys_conv_features)]


class CLIPConvLoss(torch.nn.Module):
    def __init__(self, args):
        super(CLIPConvLoss, self).__init__()
        self.clip_model_name = args.clip_model_name
        assert self.clip_model_name in [
            "RN50",
            "RN101",
            "RN50x4",
            "RN50x16",
            "ViT-B/32",
            "ViT-B/16",
        ]

        self.clip_conv_loss_type = args.clip_conv_loss_type
        self.clip_fc_loss_type = "Cos"  # args.clip_fc_loss_type
        assert self.clip_conv_loss_type in [
            "L2", "Cos", "L1",
        ]
        assert self.clip_fc_loss_type in [
            "L2", "Cos", "L1",
        ]

        self.distance_metrics = \
            {
                "L2": l2_layers,
                "L1": l1_layers,
                "Cos": cos_layers
            }

        self.model, clip_preprocess = clip.load(
            self.clip_model_name, args.device, jit=False)

        if self.clip_model_name.startswith("ViT"):
            self.visual_encoder = CLIPVisualEncoder(self.model)

        else:
            self.visual_model = self.model.visual
            layers = list(self.model.visual.children())
            init_layers = torch.nn.Sequential(*layers)[:8]
            self.layer1 = layers[8]
            self.layer2 = layers[9]
            self.layer3 = layers[10]
            self.layer4 = layers[11]
            self.att_pool2d = layers[12]

        self.args = args

        self.img_size = clip_preprocess.transforms[1].size
        self.model.eval()
        self.target_transform = transforms.Compose([
            transforms.ToTensor(),
        ])  # clip normalisation
        self.normalize_transform = transforms.Compose([
            clip_preprocess.transforms[0],  # Resize
            clip_preprocess.transforms[1],  # CenterCrop
            clip_preprocess.transforms[-1],  # Normalize
        ])

        self.model.eval()
        self.device = args.device
        self.num_augs = self.args.num_aug_clip

        augemntations = []
        if "affine" in args.augemntations:
            augemntations.append(transforms.RandomPerspective(
                fill=0, p=1.0, distortion_scale=0.5))
            augemntations.append(transforms.RandomResizedCrop(
                224, scale=(0.8, 0.8), ratio=(1.0, 1.0)))
        augemntations.append(
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)))
        self.augment_trans = transforms.Compose(augemntations)

        self.clip_fc_layer_dims = None  # self.args.clip_fc_layer_dims
        self.clip_conv_layer_dims = None  # self.args.clip_conv_layer_dims
        self.clip_fc_loss_weight = args.clip_fc_loss_weight
        self.counter = 0

    def forward(self, sketch, target, mode="train"):
        """
        Parameters
        ----------
        sketch: Torch Tensor [1, C, H, W]
        target: Torch Tensor [1, C, H, W]
        """
        # y = self.target_transform(target).to(self.args.device)
        conv_loss_dict = {}
        x = sketch.to(self.device)
        y = target.to(self.device)
        sketch_augs, img_augs = [self.normalize_transform(x)], [
            self.normalize_transform(y)]
        if mode == "train":
            for n in range(self.num_augs):
                augmented_pair = self.augment_trans(torch.cat([x, y]))
                sketch_augs.append(augmented_pair[0].unsqueeze(0))
                img_augs.append(augmented_pair[1].unsqueeze(0))

        xs = torch.cat(sketch_augs, dim=0).to(self.device)
        ys = torch.cat(img_augs, dim=0).to(self.device)

        if self.clip_model_name.startswith("RN"):
            xs_fc_features, xs_conv_features = self.forward_inspection_clip_resnet(
                xs.contiguous())
            ys_fc_features, ys_conv_features = self.forward_inspection_clip_resnet(
                ys.detach())

        else:
            xs_fc_features, xs_conv_features = self.visual_encoder(xs)
            ys_fc_features, ys_conv_features = self.visual_encoder(ys)

        conv_loss = self.distance_metrics[self.clip_conv_loss_type](
            xs_conv_features, ys_conv_features, self.clip_model_name)

        for layer, w in enumerate(self.args.clip_conv_layer_weights):
            if w:
                conv_loss_dict[f"clip_conv_loss_layer{layer}"] = conv_loss[layer] * w

        if self.clip_fc_loss_weight:
            # fc distance is always cos
            fc_loss = (1 - torch.cosine_similarity(xs_fc_features,
                                                   ys_fc_features, dim=1)).mean()
            conv_loss_dict["fc"] = fc_loss * self.clip_fc_loss_weight

        self.counter += 1
        return conv_loss_dict

    def forward_inspection_clip_resnet(self, x):
        def stem(m, x):
            for conv, bn in [(m.conv1, m.bn1), (m.conv2, m.bn2), (m.conv3, m.bn3)]:
                x = m.relu(bn(conv(x)))
            x = m.avgpool(x)
            return x
        x = x.type(self.visual_model.conv1.weight.dtype)
        x = stem(self.visual_model, x)
        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        y = self.att_pool2d(x4)
        return y, [x, x1, x2, x3, x4]
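Two of the terms computed by CLIPConvLoss.forward are easy to sanity-check in isolation. A minimal sketch with random tensors standing in for the CLIP fc embeddings of the augmented sketch/target batch, and the layer weights taken from the RN101 default in config.py:

import torch

# toy stand-ins for xs_fc_features / ys_fc_features of 5 augmented pairs
xs_fc = torch.randn(5, 512)
ys_fc = torch.randn(5, 512)
fc_loss = (1 - torch.cosine_similarity(xs_fc, ys_fc, dim=1)).mean()

# the default "0,0,1.0,1.0,0" weighting keeps only conv layers 2 and 3
weights = [0.0, 0.0, 1.0, 1.0, 0.0]
active_layers = [i for i, w in enumerate(weights) if w]   # -> [2, 3]
print(fc_loss.item(), active_layers)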
models/painter_params.py
ADDED
@@ -0,0 +1,539 @@
import random
import CLIP_.clip as clip
import numpy as np
import pydiffvg
import sketch_utils as utils
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from scipy.ndimage.filters import gaussian_filter
from skimage.color import rgb2gray
from skimage.filters import threshold_otsu
from torchvision import transforms


class Painter(torch.nn.Module):
    def __init__(self, args,
                 num_strokes=4,
                 num_segments=4,
                 imsize=224,
                 device=None,
                 target_im=None,
                 mask=None):
        super(Painter, self).__init__()

        self.args = args
        self.num_paths = num_strokes
        self.num_segments = num_segments
        self.width = args.width
        self.control_points_per_seg = args.control_points_per_seg
        self.opacity_optim = args.force_sparse
        self.num_stages = args.num_stages
        self.add_random_noise = "noise" in args.augemntations
        self.noise_thresh = args.noise_thresh
        self.softmax_temp = args.softmax_temp

        self.shapes = []
        self.shape_groups = []
        self.device = device
        self.canvas_width, self.canvas_height = imsize, imsize
        self.points_vars = []
        self.color_vars = []
        self.color_vars_threshold = args.color_vars_threshold

        self.path_svg = args.path_svg
        self.strokes_per_stage = self.num_paths
        self.optimize_flag = []

        # attention related for strokes initialisation
        self.attention_init = args.attention_init
        self.target_path = args.target
        self.saliency_model = args.saliency_model
        self.xdog_intersec = args.xdog_intersec
        self.mask_object = args.mask_object_attention

        self.text_target = args.text_target  # for clip gradients
        self.saliency_clip_model = args.saliency_clip_model
        self.define_attention_input(target_im)
        self.mask = mask
        self.attention_map = self.set_attention_map() if self.attention_init else None

        self.thresh = self.set_attention_threshold_map() if self.attention_init else None
        self.strokes_counter = 0  # counts the number of calls to "get_path"
        self.epoch = 0
        self.final_epoch = args.num_iter - 1

    def init_image(self, stage=0):
        if stage > 0:
            # in multi-stage training, add new strokes on top of the existing ones
            # and don't optimize the previous strokes
            self.optimize_flag = [False for i in range(len(self.shapes))]
            for i in range(self.strokes_per_stage):
                stroke_color = torch.tensor([0.0, 0.0, 0.0, 1.0])
                path = self.get_path()
                self.shapes.append(path)
                path_group = pydiffvg.ShapeGroup(shape_ids=torch.tensor([len(self.shapes) - 1]),
                                                 fill_color=None,
                                                 stroke_color=stroke_color)
                self.shape_groups.append(path_group)
                self.optimize_flag.append(True)

        else:
            num_paths_exists = 0
            if self.path_svg != "none":
                self.canvas_width, self.canvas_height, self.shapes, self.shape_groups = utils.load_svg(self.path_svg)
                # if you want to add more strokes to existing ones and optimize on all of them
                num_paths_exists = len(self.shapes)

            for i in range(num_paths_exists, self.num_paths):
                stroke_color = torch.tensor([0.0, 0.0, 0.0, 1.0])
                path = self.get_path()
                self.shapes.append(path)
                path_group = pydiffvg.ShapeGroup(shape_ids=torch.tensor([len(self.shapes) - 1]),
                                                 fill_color=None,
                                                 stroke_color=stroke_color)
                self.shape_groups.append(path_group)
            self.optimize_flag = [True for i in range(len(self.shapes))]

        img = self.render_warp()
        img = img[:, :, 3:4] * img[:, :, :3] + torch.ones(img.shape[0], img.shape[1], 3, device=self.device) * (1 - img[:, :, 3:4])
        img = img[:, :, :3]
        # Convert img from HWC to NCHW
        img = img.unsqueeze(0)
        img = img.permute(0, 3, 1, 2).to(self.device)  # NHWC -> NCHW
        return img
        # utils.imwrite(img.cpu(), '{}/init.png'.format(args.output_dir), gamma=args.gamma, use_wandb=args.use_wandb, wandb_name="init")

    def get_image(self):
        img = self.render_warp()
        opacity = img[:, :, 3:4]
        img = opacity * img[:, :, :3] + torch.ones(img.shape[0], img.shape[1], 3, device=self.device) * (1 - opacity)
        img = img[:, :, :3]
        # Convert img from HWC to NCHW
        img = img.unsqueeze(0)
        img = img.permute(0, 3, 1, 2).to(self.device)  # NHWC -> NCHW
        return img

    def get_path(self):
        points = []
        self.num_control_points = torch.zeros(self.num_segments, dtype=torch.int32) + (self.control_points_per_seg - 2)
        p0 = self.inds_normalised[self.strokes_counter] if self.attention_init else (random.random(), random.random())
        points.append(p0)

        for j in range(self.num_segments):
            radius = 0.05
            for k in range(self.control_points_per_seg - 1):
                p1 = (p0[0] + radius * (random.random() - 0.5), p0[1] + radius * (random.random() - 0.5))
                points.append(p1)
                p0 = p1
        points = torch.tensor(points).to(self.device)
        points[:, 0] *= self.canvas_width
        points[:, 1] *= self.canvas_height

        path = pydiffvg.Path(num_control_points=self.num_control_points,
                             points=points,
                             stroke_width=torch.tensor(self.width),
                             is_closed=False)
        self.strokes_counter += 1
        return path

    def render_warp(self):
        if self.opacity_optim:
            for group in self.shape_groups:
                group.stroke_color.data[:3].clamp_(0., 0.)  # to force black stroke
                group.stroke_color.data[-1].clamp_(0., 1.)  # opacity
                # group.stroke_color.data[-1] = (group.stroke_color.data[-1] >= self.color_vars_threshold).float()
        _render = pydiffvg.RenderFunction.apply
        # uncomment if you want to add random noise
        if self.add_random_noise:
            if random.random() > self.noise_thresh:
                eps = 0.01 * min(self.canvas_width, self.canvas_height)
                for path in self.shapes:
                    path.points.data.add_(eps * torch.randn_like(path.points))
        scene_args = pydiffvg.RenderFunction.serialize_scene(
            self.canvas_width, self.canvas_height, self.shapes, self.shape_groups)
        img = _render(self.canvas_width,  # width
                      self.canvas_height,  # height
                      2,  # num_samples_x
                      2,  # num_samples_y
                      0,  # seed
                      None,
                      *scene_args)
        return img

    def parameters(self):
        self.points_vars = []
        # strokes' location optimization
        for i, path in enumerate(self.shapes):
            if self.optimize_flag[i]:
                path.points.requires_grad = True
                self.points_vars.append(path.points)
        return self.points_vars

    def get_points_parans(self):
        return self.points_vars

    def set_color_parameters(self):
        # for strokes' color optimization (opacity)
        self.color_vars = []
        for i, group in enumerate(self.shape_groups):
            if self.optimize_flag[i]:
                group.stroke_color.requires_grad = True
                self.color_vars.append(group.stroke_color)
        return self.color_vars

    def get_color_parameters(self):
        return self.color_vars

    def save_svg(self, output_dir, name):
        pydiffvg.save_svg('{}/{}.svg'.format(output_dir, name), self.canvas_width, self.canvas_height, self.shapes, self.shape_groups)

    def dino_attn(self):
        patch_size = 8  # dino hyperparameter
        threshold = 0.6

        # for dino model
        mean_imagenet = torch.Tensor([0.485, 0.456, 0.406])[None, :, None, None].to(self.device)
        std_imagenet = torch.Tensor([0.229, 0.224, 0.225])[None, :, None, None].to(self.device)
        totens = transforms.Compose([
            transforms.Resize((self.canvas_height, self.canvas_width)),
            transforms.ToTensor()
        ])

        dino_model = torch.hub.load('facebookresearch/dino:main', 'dino_vits8').eval().to(self.device)

        self.main_im = Image.open(self.target_path).convert("RGB")
        main_im_tensor = totens(self.main_im).to(self.device)
        img = (main_im_tensor.unsqueeze(0) - mean_imagenet) / std_imagenet
        w_featmap = img.shape[-2] // patch_size
        h_featmap = img.shape[-1] // patch_size

        with torch.no_grad():
            attn = dino_model.get_last_selfattention(img).detach().cpu()[0]

        nh = attn.shape[0]
        attn = attn[:, 0, 1:].reshape(nh, -1)
        val, idx = torch.sort(attn)
        val /= torch.sum(val, dim=1, keepdim=True)
        cumval = torch.cumsum(val, dim=1)
        th_attn = cumval > (1 - threshold)
        idx2 = torch.argsort(idx)
        for head in range(nh):
            th_attn[head] = th_attn[head][idx2[head]]
        th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
        th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=patch_size, mode="nearest")[0].cpu()

        attn = attn.reshape(nh, w_featmap, h_featmap).float()
        attn = nn.functional.interpolate(attn.unsqueeze(0), scale_factor=patch_size, mode="nearest")[0].cpu()

        return attn

    def define_attention_input(self, target_im):
        model, preprocess = clip.load(self.saliency_clip_model, device=self.device, jit=False)
        model.eval().to(self.device)
        data_transforms = transforms.Compose([
            preprocess.transforms[-1],
        ])
        self.image_input_attn_clip = data_transforms(target_im).to(self.device)

    def clip_attn(self):
        model, preprocess = clip.load(self.saliency_clip_model, device=self.device, jit=False)
        model.eval().to(self.device)
        text_input = clip.tokenize([self.text_target]).to(self.device)

        if "RN" in self.saliency_clip_model:
            saliency_layer = "layer4"
            attn_map = gradCAM(
                model.visual,
                self.image_input_attn_clip,
                model.encode_text(text_input).float(),
                getattr(model.visual, saliency_layer)
            )
            attn_map = attn_map.squeeze().detach().cpu().numpy()
            attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min())

        else:
            # attn_map = interpret(self.image_input_attn_clip, text_input, model, device=self.device, index=0).astype(np.float32)
            attn_map = interpret(self.image_input_attn_clip, text_input, model, device=self.device)

        del model
        return attn_map

    def set_attention_map(self):
        assert self.saliency_model in ["dino", "clip"]
        if self.saliency_model == "dino":
            return self.dino_attn()
        elif self.saliency_model == "clip":
            return self.clip_attn()

    def softmax(self, x, tau=0.2):
        e_x = np.exp(x / tau)
        return e_x / e_x.sum()

    def set_inds_clip(self):
        attn_map = (self.attention_map - self.attention_map.min()) / (self.attention_map.max() - self.attention_map.min())
        if self.xdog_intersec:
            xdog = XDoG_()
            im_xdog = xdog(self.image_input_attn_clip[0].permute(1, 2, 0).cpu().numpy(), k=10)
            intersec_map = (1 - im_xdog) * attn_map
            attn_map = intersec_map

        attn_map_soft = np.copy(attn_map)
        attn_map_soft[attn_map > 0] = self.softmax(attn_map[attn_map > 0], tau=self.softmax_temp)

        k = self.num_stages * self.num_paths
        self.inds = np.random.choice(range(attn_map.flatten().shape[0]), size=k, replace=False, p=attn_map_soft.flatten())
        self.inds = np.array(np.unravel_index(self.inds, attn_map.shape)).T

        self.inds_normalised = np.zeros(self.inds.shape)
        self.inds_normalised[:, 0] = self.inds[:, 1] / self.canvas_width
        self.inds_normalised[:, 1] = self.inds[:, 0] / self.canvas_height
        self.inds_normalised = self.inds_normalised.tolist()
        return attn_map_soft

    def set_inds_dino(self):
        k = max(3, (self.num_stages * self.num_paths) // 6 + 1)  # sample top 3 points from each attention head
        num_heads = self.attention_map.shape[0]
        self.inds = np.zeros((k * num_heads, 2))
        # "thresh" is used for visualisation purposes only
        thresh = torch.zeros(num_heads + 1, self.attention_map.shape[1], self.attention_map.shape[2])
        softmax = nn.Softmax(dim=1)
        for i in range(num_heads):
            # replace "self.attention_map[i]" with "self.attention_map" to get the highest values among
            # all heads.
            topk, indices = np.unique(self.attention_map[i].numpy(), return_index=True)
            topk = topk[::-1][:k]
            cur_attn_map = self.attention_map[i].numpy()
            # prob function for uniform sampling
            prob = cur_attn_map.flatten()
            prob[prob > topk[-1]] = 1
            prob[prob <= topk[-1]] = 0
            prob = prob / prob.sum()
            thresh[i] = torch.Tensor(prob.reshape(cur_attn_map.shape))

            # choose k pixels from each head
            inds = np.random.choice(range(cur_attn_map.flatten().shape[0]), size=k, replace=False, p=prob)
            inds = np.unravel_index(inds, cur_attn_map.shape)
            self.inds[i * k: i * k + k, 0] = inds[0]
            self.inds[i * k: i * k + k, 1] = inds[1]

        # for visualisation
        sum_attn = self.attention_map.sum(0).numpy()
        mask = np.zeros(sum_attn.shape)
        mask[thresh[:-1].sum(0) > 0] = 1
        sum_attn = sum_attn * mask
        sum_attn = sum_attn / sum_attn.sum()
        thresh[-1] = torch.Tensor(sum_attn)

        # sample num_paths from the chosen pixels.
        prob_sum = sum_attn[self.inds[:, 0].astype(np.int), self.inds[:, 1].astype(np.int)]
        prob_sum = prob_sum / prob_sum.sum()
        new_inds = []
        for i in range(self.num_stages):
            new_inds.extend(np.random.choice(range(self.inds.shape[0]), size=self.num_paths, replace=False, p=prob_sum))
        self.inds = self.inds[new_inds]
        print("self.inds", self.inds.shape)

        self.inds_normalised = np.zeros(self.inds.shape)
        self.inds_normalised[:, 0] = self.inds[:, 1] / self.canvas_width
        self.inds_normalised[:, 1] = self.inds[:, 0] / self.canvas_height
        self.inds_normalised = self.inds_normalised.tolist()
        return thresh

    def set_attention_threshold_map(self):
        assert self.saliency_model in ["dino", "clip"]
        if self.saliency_model == "dino":
            return self.set_inds_dino()
        elif self.saliency_model == "clip":
            return self.set_inds_clip()

    def get_attn(self):
        return self.attention_map

    def get_thresh(self):
        return self.thresh

    def get_inds(self):
        return self.inds

    def get_mask(self):
        return self.mask

    def set_random_noise(self, epoch):
        if epoch % self.args.save_interval == 0:
            self.add_random_noise = False
        else:
            self.add_random_noise = "noise" in self.args.augemntations


class PainterOptimizer:
    def __init__(self, args, renderer):
        self.renderer = renderer
        self.points_lr = args.lr
        self.color_lr = args.color_lr
        self.args = args
        self.optim_color = args.force_sparse

    def init_optimizers(self):
        self.points_optim = torch.optim.Adam(self.renderer.parameters(), lr=self.points_lr)
        if self.optim_color:
            self.color_optim = torch.optim.Adam(self.renderer.set_color_parameters(), lr=self.color_lr)

    def update_lr(self, counter):
        new_lr = utils.get_epoch_lr(counter, self.args)
        for param_group in self.points_optim.param_groups:
            param_group["lr"] = new_lr

    def zero_grad_(self):
        self.points_optim.zero_grad()
        if self.optim_color:
            self.color_optim.zero_grad()

    def step_(self):
        self.points_optim.step()
        if self.optim_color:
            self.color_optim.step()

    def get_lr(self):
        return self.points_optim.param_groups[0]['lr']


class Hook:
    """Attaches to a module and records its activations and gradients."""

    def __init__(self, module: nn.Module):
        self.data = None
        self.hook = module.register_forward_hook(self.save_grad)

    def save_grad(self, module, input, output):
        self.data = output
        output.requires_grad_(True)
        output.retain_grad()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.hook.remove()

    @property
    def activation(self) -> torch.Tensor:
        return self.data

    @property
    def gradient(self) -> torch.Tensor:
        return self.data.grad


def interpret(image, texts, model, device):
    images = image.repeat(1, 1, 1, 1)
    res = model.encode_image(images)
    model.zero_grad()
    image_attn_blocks = list(dict(model.visual.transformer.resblocks.named_children()).values())
    num_tokens = image_attn_blocks[0].attn_probs.shape[-1]
    R = torch.eye(num_tokens, num_tokens, dtype=image_attn_blocks[0].attn_probs.dtype).to(device)
    R = R.unsqueeze(0).expand(1, num_tokens, num_tokens)
    cams = []  # there are 12 attention blocks
    for i, blk in enumerate(image_attn_blocks):
        cam = blk.attn_probs.detach()  # attn_probs shape is 12, 50, 50
        # each patch is 7x7 so we have 49 pixels + 1 for positional encoding
        cam = cam.reshape(1, -1, cam.shape[-1], cam.shape[-1])
        cam = cam.clamp(min=0)
        cam = cam.clamp(min=0).mean(dim=1)  # mean over the attention heads
        cams.append(cam)
        R = R + torch.bmm(cam, R)

    cams_avg = torch.cat(cams)  # 12, 50, 50
    cams_avg = cams_avg[:, 0, 1:]  # 12, 1, 49
    image_relevance = cams_avg.mean(dim=0).unsqueeze(0)
    image_relevance = image_relevance.reshape(1, 1, 7, 7)
    image_relevance = torch.nn.functional.interpolate(image_relevance, size=224, mode='bicubic')
    image_relevance = image_relevance.reshape(224, 224).data.cpu().numpy().astype(np.float32)
    image_relevance = (image_relevance - image_relevance.min()) / (image_relevance.max() - image_relevance.min())
    return image_relevance


# Reference: https://arxiv.org/abs/1610.02391
def gradCAM(
    model: nn.Module,
    input: torch.Tensor,
    target: torch.Tensor,
    layer: nn.Module
) -> torch.Tensor:
    # Zero out any gradients at the input.
    if input.grad is not None:
        input.grad.data.zero_()

    # Disable gradient settings.
    requires_grad = {}
    for name, param in model.named_parameters():
        requires_grad[name] = param.requires_grad
        param.requires_grad_(False)

    # Attach a hook to the model at the desired layer.
    assert isinstance(layer, nn.Module)
    with Hook(layer) as hook:
        # Do a forward and backward pass.
        output = model(input)
        output.backward(target)

        grad = hook.gradient.float()
        act = hook.activation.float()

        # Global average pool gradient across spatial dimension
        # to obtain importance weights.
        alpha = grad.mean(dim=(2, 3), keepdim=True)
        # Weighted combination of activation maps over channel
        # dimension.
        gradcam = torch.sum(act * alpha, dim=1, keepdim=True)
        # We only want neurons with positive influence so we
        # clamp any negative ones.
        gradcam = torch.clamp(gradcam, min=0)

    # Resize gradcam to input resolution.
    gradcam = F.interpolate(
        gradcam,
        input.shape[2:],
        mode='bicubic',
        align_corners=False)

    # Restore gradient settings.
    for name, param in model.named_parameters():
        param.requires_grad_(requires_grad[name])

    return gradcam


class XDoG_(object):
    def __init__(self):
        super(XDoG_, self).__init__()
        self.gamma = 0.98
        self.phi = 200
        self.eps = -0.1
        self.sigma = 0.8
        self.binarize = True

    def __call__(self, im, k=10):
        if im.shape[2] == 3:
            im = rgb2gray(im)
        imf1 = gaussian_filter(im, self.sigma)
        imf2 = gaussian_filter(im, self.sigma * k)
        imdiff = imf1 - self.gamma * imf2
        imdiff = (imdiff < self.eps) * 1.0 + (imdiff >= self.eps) * (1.0 + np.tanh(self.phi * imdiff))
        imdiff -= imdiff.min()
        imdiff /= imdiff.max()
        if self.binarize:
            th = threshold_otsu(imdiff)
            imdiff = imdiff >= th
            imdiff = imdiff.astype('float32')
        return imdiff
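set_inds_clip above boils down to sampling stroke starting points from a temperature-softmaxed relevance map. A toy version of that sampling, with a random 4x4 array standing in for the CLIP attention map (the real code also intersects the map with an XDoG edge map and only softmaxes the positive entries):

import numpy as np

# random 4x4 map standing in for the normalised CLIP relevance map
attn = np.random.rand(4, 4)
attn = (attn - attn.min()) / (attn.max() - attn.min())

tau = 0.3                       # the default --softmax_temp
prob = np.exp(attn / tau)
prob = prob / prob.sum()

k = 3                           # num_stages * num_paths in the real code
flat = np.random.choice(prob.size, size=k, replace=False, p=prob.flatten())
yx = np.array(np.unravel_index(flat, attn.shape)).T   # (row, col) per stroke
xy_normalised = yx[:, ::-1] / 4.0                      # becomes the first control point of each path
print(xy_normalised)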
painterly_rendering.py
ADDED
@@ -0,0 +1,190 @@
import warnings

warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

import argparse
import math
import os
import sys
import time
import traceback

import numpy as np
import PIL
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from PIL import Image
from torchvision import models, transforms
from tqdm.auto import tqdm, trange

import config
import sketch_utils as utils
from models.loss import Loss
from models.painter_params import Painter, PainterOptimizer
from IPython.display import display, SVG


def load_renderer(args, target_im=None, mask=None):
    renderer = Painter(num_strokes=args.num_paths, args=args,
                       num_segments=args.num_segments,
                       imsize=args.image_scale,
                       device=args.device,
                       target_im=target_im,
                       mask=mask)
    renderer = renderer.to(args.device)
    return renderer


def get_target(args):
    target = Image.open(args.target)
    if target.mode == "RGBA":
        # Create a white rgba background
        new_image = Image.new("RGBA", target.size, "WHITE")
        # Paste the image on the background.
        new_image.paste(target, (0, 0), target)
        target = new_image
    target = target.convert("RGB")
    masked_im, mask = utils.get_mask_u2net(args, target)
    if args.mask_object:
        target = masked_im
    if args.fix_scale:
        target = utils.fix_image_scale(target)

    transforms_ = []
    if target.size[0] != target.size[1]:
        transforms_.append(transforms.Resize(
            (args.image_scale, args.image_scale), interpolation=PIL.Image.BICUBIC))
    else:
        transforms_.append(transforms.Resize(
            args.image_scale, interpolation=PIL.Image.BICUBIC))
        transforms_.append(transforms.CenterCrop(args.image_scale))
    transforms_.append(transforms.ToTensor())
    data_transforms = transforms.Compose(transforms_)
    target_ = data_transforms(target).unsqueeze(0).to(args.device)
    return target_, mask


def main(args):
    loss_func = Loss(args)
    inputs, mask = get_target(args)
    utils.log_input(args.use_wandb, 0, inputs, args.output_dir)
    renderer = load_renderer(args, inputs, mask)

    optimizer = PainterOptimizer(args, renderer)
    counter = 0
    configs_to_save = {"loss_eval": []}
    best_loss, best_fc_loss = 100, 100
    best_iter, best_iter_fc = 0, 0
    min_delta = 1e-5
    terminate = False

    renderer.set_random_noise(0)
    img = renderer.init_image(stage=0)
    optimizer.init_optimizers()

    # not using tqdm for the jupyter demo
    if args.display:
        epoch_range = range(args.num_iter)
    else:
        epoch_range = tqdm(range(args.num_iter))

    for epoch in epoch_range:
        if not args.display:
            epoch_range.refresh()
        renderer.set_random_noise(epoch)
        if args.lr_scheduler:
            optimizer.update_lr(counter)

        start = time.time()
        optimizer.zero_grad_()
        sketches = renderer.get_image().to(args.device)
        losses_dict = loss_func(sketches, inputs.detach(), renderer.get_color_parameters(),
                                renderer, counter, optimizer)
        loss = sum(list(losses_dict.values()))
        loss.backward()
        optimizer.step_()
        if epoch % args.save_interval == 0:
            utils.plot_batch(inputs, sketches, f"{args.output_dir}/jpg_logs", counter,
                             use_wandb=args.use_wandb, title=f"iter{epoch}.jpg")
            renderer.save_svg(
                f"{args.output_dir}/svg_logs", f"svg_iter{epoch}")
        if epoch % args.eval_interval == 0:
            with torch.no_grad():
                losses_dict_eval = loss_func(sketches, inputs, renderer.get_color_parameters(),
                                             renderer.get_points_parans(), counter, optimizer, mode="eval")
                loss_eval = sum(list(losses_dict_eval.values()))
                configs_to_save["loss_eval"].append(loss_eval.item())
                for k in losses_dict_eval.keys():
                    if k not in configs_to_save.keys():
                        configs_to_save[k] = []
                    configs_to_save[k].append(losses_dict_eval[k].item())
                if args.clip_fc_loss_weight:
                    if losses_dict_eval["fc"].item() < best_fc_loss:
                        best_fc_loss = losses_dict_eval["fc"].item() / args.clip_fc_loss_weight
                        best_iter_fc = epoch
                # print(
                #     f"eval iter[{epoch}/{args.num_iter}] loss[{loss.item()}] time[{time.time() - start}]")

                cur_delta = loss_eval.item() - best_loss
                if abs(cur_delta) > min_delta:
                    if cur_delta < 0:
                        best_loss = loss_eval.item()
                        best_iter = epoch
                        terminate = False
                        utils.plot_batch(
                            inputs, sketches, args.output_dir, counter, use_wandb=args.use_wandb, title="best_iter.jpg")
                        renderer.save_svg(args.output_dir, "best_iter")

                if args.use_wandb:
                    wandb.run.summary["best_loss"] = best_loss
                    wandb.run.summary["best_loss_fc"] = best_fc_loss
                    wandb_dict = {"delta": cur_delta,
                                  "loss_eval": loss_eval.item()}
                    for k in losses_dict_eval.keys():
                        wandb_dict[k + "_eval"] = losses_dict_eval[k].item()
                    wandb.log(wandb_dict, step=counter)

                if abs(cur_delta) <= min_delta:
                    if terminate:
                        break
                    terminate = True

        if counter == 0 and args.attention_init:
            utils.plot_atten(renderer.get_attn(), renderer.get_thresh(), inputs, renderer.get_inds(),
                             args.use_wandb, "{}/{}.jpg".format(
                                 args.output_dir, "attention_map"),
                             args.saliency_model, args.display_logs)

        if args.use_wandb:
            wandb_dict = {"loss": loss.item(), "lr": optimizer.get_lr()}
            for k in losses_dict.keys():
                wandb_dict[k] = losses_dict[k].item()
            wandb.log(wandb_dict, step=counter)

        counter += 1

    renderer.save_svg(args.output_dir, "final_svg")
    path_svg = os.path.join(args.output_dir, "best_iter.svg")
    utils.log_sketch_summary_final(
        path_svg, args.use_wandb, args.device, best_iter, best_loss, "best total")

    return configs_to_save


if __name__ == "__main__":
    args = config.parse_arguments()
    final_config = vars(args)
    try:
        configs_to_save = main(args)
    except BaseException as err:
        print(f"Unexpected error occurred:\n {err}")
        print(traceback.format_exc())
        sys.exit(1)
    for k in configs_to_save.keys():
        final_config[k] = configs_to_save[k]
    np.save(f"{args.output_dir}/config.npy", final_config)
    if args.use_wandb:
        wandb.finish()
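Note: main() stops early once the evaluation loss plateaus — two consecutive evaluations whose change from the best loss stays within min_delta (1e-5) break the loop. A standalone restatement of that criterion, for illustration only (the helper name is invented and is not part of the repo):

# Illustration of the early-stopping rule used in main() (not repo code).
def should_stop(loss_history, min_delta=1e-5):
    """Return True once two consecutive eval losses stop moving by more than min_delta."""
    best = float("inf")
    terminate = False
    for loss in loss_history:
        delta = loss - best
        if abs(delta) > min_delta and delta < 0:
            best = loss            # genuine improvement: remember it and keep going
            terminate = False
        elif abs(delta) <= min_delta:
            if terminate:          # second plateau in a row -> stop
                return True
            terminate = True
    return False

# e.g. should_stop([0.9, 0.5, 0.500001, 0.500002]) -> True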
predict.py
ADDED
@@ -0,0 +1,317 @@
# sudo cog push r8.im/yael-vinker/clipasso

# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md

import warnings

warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

from cog import BasePredictor, Input, Path
import subprocess as sp
import os
import re

import imageio
import matplotlib.pyplot as plt
import numpy as np
import pydiffvg
import torch
from PIL import Image
import multiprocessing as mp
from shutil import copyfile

import argparse
import math
import sys
import time
import traceback

import PIL
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torchvision import models, transforms
from tqdm import tqdm

import config
import sketch_utils as utils
from models.loss import Loss
from models.painter_params import Painter, PainterOptimizer


class Predictor(BasePredictor):
    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        self.num_iter = 2001
        self.save_interval = 100
        self.num_sketches = 3
        self.use_gpu = True

    def predict(
        self,
        target_image: Path = Input(description="Input image (square, without background)"),
        num_strokes: int = Input(description="The number of strokes used to create the sketch, which determines the level of abstraction", default=16),
        trials: int = Input(description="It is recommended to use 3 trials to receive the best sketch, but it might be slower", default=3),
        mask_object: int = Input(description="It is recommended to use images without a background; however, if your image contains a background, you can mask it out by using this flag with 1 as an argument", default=0),
        fix_scale: int = Input(description="If your image is not squared, it might be cut off; it is recommended to use this flag with 1 as input to automatically fix the scale without cutting the image", default=0),
    ) -> Path:

        self.num_sketches = trials
        target_image_name = os.path.basename(str(target_image))

        multiprocess = False
        abs_path = os.path.abspath(os.getcwd())

        target = str(target_image)
        assert os.path.isfile(target), f"{target} does not exist!"

        test_name = os.path.splitext(target_image_name)[0]
        output_dir = f"{abs_path}/output_sketches/{test_name}/"
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        print("=" * 50)
        print(f"Processing [{target_image_name}] ...")
        print(f"Results will be saved to \n[{output_dir}] ...")
        print("=" * 50)

        if not torch.cuda.is_available():
            self.use_gpu = False
            print("CUDA is not configured with GPU, running with CPU instead.")
            print("Note that this will be very slow, it is recommended to use colab.")
        print(f"GPU: {self.use_gpu}")
        seeds = list(range(0, self.num_sketches * 1000, 1000))

        losses_all = {}

        for seed in seeds:
            wandb_name = f"{test_name}_{num_strokes}strokes_seed{seed}"
            sp.run(["python", "config.py", target,
                    "--num_paths", str(num_strokes),
                    "--output_dir", output_dir,
                    "--wandb_name", wandb_name,
                    "--num_iter", str(self.num_iter),
                    "--save_interval", str(self.save_interval),
                    "--seed", str(seed),
                    "--use_gpu", str(int(self.use_gpu)),
                    "--fix_scale", str(fix_scale),
                    "--mask_object", str(mask_object),
                    "--mask_object_attention", str(mask_object),
                    "--display_logs", str(int(0))])
            config_init = np.load(f"{output_dir}/{wandb_name}/config_init.npy", allow_pickle=True)[()]
            args = Args(config_init)
            args.cog_display = True

            final_config = vars(args)
            try:
                configs_to_save = main(args)
            except BaseException as err:
                print(f"Unexpected error occurred:\n {err}")
                print(traceback.format_exc())
                sys.exit(1)
            for k in configs_to_save.keys():
                final_config[k] = configs_to_save[k]
            np.save(f"{args.output_dir}/config.npy", final_config)
            if args.use_wandb:
                wandb.finish()

            config = np.load(f"{output_dir}/{wandb_name}/config.npy",
                             allow_pickle=True)[()]
            loss_eval = np.array(config['loss_eval'])
            inds = np.argsort(loss_eval)
            losses_all[wandb_name] = loss_eval[inds][0]
            # return Path(f"{output_dir}/{wandb_name}/best_iter.svg")

        sorted_final = dict(sorted(losses_all.items(), key=lambda item: item[1]))
        copyfile(f"{output_dir}/{list(sorted_final.keys())[0]}/best_iter.svg",
                 f"{output_dir}/{list(sorted_final.keys())[0]}_best.svg")
        target_path = f"{abs_path}/target_images/{target_image_name}"
        svg_files = os.listdir(output_dir)
        svg_files = [f for f in svg_files if "best.svg" in f]
        svg_output_path = f"{output_dir}/{svg_files[0]}"
        sketch_res = read_svg(svg_output_path, multiply=True).cpu().numpy()
        sketch_res = Image.fromarray((sketch_res * 255).astype('uint8'), 'RGB')
        sketch_res.save(f"{abs_path}/output_sketches/sketch.png")
        return Path(svg_output_path)


class Args():
    def __init__(self, config):
        for k in config.keys():
            setattr(self, k, config[k])


def load_renderer(args, target_im=None, mask=None):
    renderer = Painter(num_strokes=args.num_paths, args=args,
                       num_segments=args.num_segments,
                       imsize=args.image_scale,
                       device=args.device,
                       target_im=target_im,
                       mask=mask)
    renderer = renderer.to(args.device)
    return renderer


def get_target(args):
    target = Image.open(args.target)
    if target.mode == "RGBA":
        # Create a white rgba background
        new_image = Image.new("RGBA", target.size, "WHITE")
        # Paste the image on the background.
        new_image.paste(target, (0, 0), target)
        target = new_image
    target = target.convert("RGB")
    masked_im, mask = utils.get_mask_u2net(args, target)
    if args.mask_object:
        target = masked_im
    if args.fix_scale:
        target = utils.fix_image_scale(target)

    transforms_ = []
    if target.size[0] != target.size[1]:
        transforms_.append(transforms.Resize(
            (args.image_scale, args.image_scale), interpolation=PIL.Image.BICUBIC))
    else:
        transforms_.append(transforms.Resize(
            args.image_scale, interpolation=PIL.Image.BICUBIC))
        transforms_.append(transforms.CenterCrop(args.image_scale))
    transforms_.append(transforms.ToTensor())
    data_transforms = transforms.Compose(transforms_)
    target_ = data_transforms(target).unsqueeze(0).to(args.device)
    return target_, mask


def main(args):
    loss_func = Loss(args)
    inputs, mask = get_target(args)
    utils.log_input(args.use_wandb, 0, inputs, args.output_dir)
    renderer = load_renderer(args, inputs, mask)

    optimizer = PainterOptimizer(args, renderer)
    counter = 0
    configs_to_save = {"loss_eval": []}
    best_loss, best_fc_loss = 100, 100
    best_iter, best_iter_fc = 0, 0
    min_delta = 1e-5
    terminate = False

    renderer.set_random_noise(0)
    img = renderer.init_image(stage=0)
    optimizer.init_optimizers()

    for epoch in tqdm(range(args.num_iter)):
        renderer.set_random_noise(epoch)
        if args.lr_scheduler:
            optimizer.update_lr(counter)

        start = time.time()
        optimizer.zero_grad_()
        sketches = renderer.get_image().to(args.device)
        losses_dict = loss_func(sketches, inputs.detach(), renderer.get_color_parameters(),
                                renderer, counter, optimizer)
        loss = sum(list(losses_dict.values()))
        loss.backward()
        optimizer.step_()
        if epoch % args.save_interval == 0:
            utils.plot_batch(inputs, sketches, f"{args.output_dir}/jpg_logs", counter,
                             use_wandb=args.use_wandb, title=f"iter{epoch}.jpg")
            renderer.save_svg(
                f"{args.output_dir}/svg_logs", f"svg_iter{epoch}")
            # if args.cog_display:
            #     yield Path(f"{args.output_dir}/svg_logs/svg_iter{epoch}.svg")

        if epoch % args.eval_interval == 0:
            with torch.no_grad():
                losses_dict_eval = loss_func(sketches, inputs, renderer.get_color_parameters(),
                                             renderer.get_points_parans(), counter, optimizer, mode="eval")
                loss_eval = sum(list(losses_dict_eval.values()))
                configs_to_save["loss_eval"].append(loss_eval.item())
                for k in losses_dict_eval.keys():
                    if k not in configs_to_save.keys():
                        configs_to_save[k] = []
                    configs_to_save[k].append(losses_dict_eval[k].item())
                if args.clip_fc_loss_weight:
                    if losses_dict_eval["fc"].item() < best_fc_loss:
                        best_fc_loss = losses_dict_eval["fc"].item() / args.clip_fc_loss_weight
                        best_iter_fc = epoch
                # print(
                #     f"eval iter[{epoch}/{args.num_iter}] loss[{loss.item()}] time[{time.time() - start}]")

                cur_delta = loss_eval.item() - best_loss
                if abs(cur_delta) > min_delta:
                    if cur_delta < 0:
                        best_loss = loss_eval.item()
                        best_iter = epoch
                        terminate = False
                        utils.plot_batch(
                            inputs, sketches, args.output_dir, counter, use_wandb=args.use_wandb, title="best_iter.jpg")
                        renderer.save_svg(args.output_dir, "best_iter")

                if args.use_wandb:
                    wandb.run.summary["best_loss"] = best_loss
                    wandb.run.summary["best_loss_fc"] = best_fc_loss
                    wandb_dict = {"delta": cur_delta,
                                  "loss_eval": loss_eval.item()}
                    for k in losses_dict_eval.keys():
                        wandb_dict[k + "_eval"] = losses_dict_eval[k].item()
                    wandb.log(wandb_dict, step=counter)

                if abs(cur_delta) <= min_delta:
                    if terminate:
                        break
                    terminate = True

        if counter == 0 and args.attention_init:
            utils.plot_atten(renderer.get_attn(), renderer.get_thresh(), inputs, renderer.get_inds(),
                             args.use_wandb, "{}/{}.jpg".format(
                                 args.output_dir, "attention_map"),
                             args.saliency_model, args.display_logs)

        if args.use_wandb:
            wandb_dict = {"loss": loss.item(), "lr": optimizer.get_lr()}
            for k in losses_dict.keys():
                wandb_dict[k] = losses_dict[k].item()
            wandb.log(wandb_dict, step=counter)

        counter += 1

    renderer.save_svg(args.output_dir, "final_svg")
    path_svg = os.path.join(args.output_dir, "best_iter.svg")
    utils.log_sketch_summary_final(
        path_svg, args.use_wandb, args.device, best_iter, best_loss, "best total")

    return configs_to_save


def read_svg(path_svg, multiply=False):
    device = torch.device("cuda" if (
        torch.cuda.is_available() and torch.cuda.device_count() > 0) else "cpu")
    canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(
        path_svg)
    if multiply:
        canvas_width *= 2
        canvas_height *= 2
        for path in shapes:
            path.points *= 2
            path.stroke_width *= 2
    _render = pydiffvg.RenderFunction.apply
    scene_args = pydiffvg.RenderFunction.serialize_scene(
        canvas_width, canvas_height, shapes, shape_groups)
    img = _render(canvas_width,   # width
                  canvas_height,  # height
                  2,              # num_samples_x
                  2,              # num_samples_y
                  0,              # seed
                  None,
                  *scene_args)
    img = img[:, :, 3:4] * img[:, :, :3] + \
        torch.ones(img.shape[0], img.shape[1], 3,
                   device=device) * (1 - img[:, :, 3:4])
    img = img[:, :, :3]
    return img
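Note: read_svg() composites the HxWx4 RGBA raster produced by pydiffvg onto a white canvas before returning an RGB tensor, i.e. rgb_out = alpha * rgb + (1 - alpha) * 1. A toy sketch of just that compositing step (illustration only, not repo code):

# Compositing an RGBA render over white (illustration only).
import torch

rgba = torch.rand(4, 4, 4)                       # dummy HxWx4 render
alpha, rgb = rgba[:, :, 3:4], rgba[:, :, :3]
white = torch.ones_like(rgb)
composited = alpha * rgb + (1 - alpha) * white   # HxWx3, white wherever nothing was drawn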
requirements.txt
ADDED
@@ -0,0 +1,53 @@
ftfy==6.0.3
cssutils==2.3.0
gdown==4.4.0
imageio==2.9.0
imageio-ffmpeg==0.4.4
importlib-metadata==4.6.4
ipykernel==6.1.0
ipython==7.26.0
ipython-genutils==0.2.0
json5==0.9.5
jsonpatch==1.32
jsonpointer==2.1
jsonschema==3.2.0
jupyter-client==6.1.12
jupyter-core==4.7.1
jupyter-server==1.10.2
jupyterlab==3.1.6
jupyterlab-pygments==0.1.2
jupyterlab-server==2.7.0
matplotlib==3.4.2
matplotlib-inline==0.1.2
moviepy==1.0.3
notebook==6.4.3
numba==0.53.1
numpy==1.20.3
nvidia-ml-py3==7.352.0
opencv-python==4.5.3.56
pandas==1.3.2
pathtools==0.1.2
Pillow==8.2.0
pip==21.2.2
plotly==5.2.1
psutil==5.8.0
ptyprocess==0.7.0
pyaml==21.8.3
regex==2021.11.10
scikit-image==0.18.1
scikit-learn==1.0.2
scipy==1.6.2
seaborn==0.11.2
subprocess32==3.5.4
svgpathtools==1.4.1
svgwrite==1.4.1
torch==1.7.1
torch-tools==0.1.5
torchfile==0.1.0
torchvision==0.8.2
tqdm==4.62.1
visdom==0.1.8.9
wandb==0.12.0
webencodings==0.5.1
websocket-client==0.57.0
zipp==3.5.0
run_object_sketching.py
ADDED
@@ -0,0 +1,158 @@
import sys
import warnings

warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

import argparse
import multiprocessing as mp
import os
import subprocess as sp
from shutil import copyfile

import numpy as np
import torch
from IPython.display import Image as Image_colab
from IPython.display import display, SVG, clear_output
from ipywidgets import IntSlider, Output, IntProgress, Button
import time

parser = argparse.ArgumentParser()
parser.add_argument("--target_file", type=str,
                    help="target image file, located in <target_images>")
parser.add_argument("--num_strokes", type=int, default=16,
                    help="number of strokes used to generate the sketch, this defines the level of abstraction.")
parser.add_argument("--num_iter", type=int, default=2001,
                    help="number of iterations")
parser.add_argument("--fix_scale", type=int, default=0,
                    help="if the target image is not squared, it is recommended to fix the scale")
parser.add_argument("--mask_object", type=int, default=0,
                    help="if the target image contains background, it's better to mask it out")
parser.add_argument("--num_sketches", type=int, default=3,
                    help="it is recommended to draw 3 sketches and automatically choose the best one")
parser.add_argument("--multiprocess", type=int, default=0,
                    help="recommended to use multiprocess if your computer has enough memory")
parser.add_argument('-colab', action='store_true')
parser.add_argument('-cpu', action='store_true')
parser.add_argument('-display', action='store_true')
parser.add_argument('--gpunum', type=int, default=0)

args = parser.parse_args()

multiprocess = not args.colab and args.num_sketches > 1 and args.multiprocess

abs_path = os.path.abspath(os.getcwd())

target = f"{abs_path}/target_images/{args.target_file}"
assert os.path.isfile(target), f"{target} does not exist!"

if not os.path.isfile(f"{abs_path}/U2Net_/saved_models/u2net.pth"):
    sp.run(["gdown", "https://drive.google.com/uc?id=1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ",
            "-O", "U2Net_/saved_models/"])

test_name = os.path.splitext(args.target_file)[0]
output_dir = f"{abs_path}/output_sketches/{test_name}/"
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

num_iter = args.num_iter
save_interval = 10
use_gpu = not args.cpu

if not torch.cuda.is_available():
    use_gpu = False
    print("CUDA is not configured with GPU, running with CPU instead.")
    print("Note that this will be very slow, it is recommended to use colab.")

if args.colab:
    print("=" * 50)
    print(f"Processing [{args.target_file}] ...")
    if args.colab or args.display:
        img_ = Image_colab(target)
        display(img_)
    print(f"GPU: {use_gpu}, {torch.cuda.current_device()}")
    print(f"Results will be saved to \n[{output_dir}] ...")
    print("=" * 50)

seeds = list(range(0, args.num_sketches * 1000, 1000))

exit_codes = []
manager = mp.Manager()
losses_all = manager.dict()


def run(seed, wandb_name):
    exit_code = sp.run(["python", "painterly_rendering.py", target,
                        "--num_paths", str(args.num_strokes),
                        "--output_dir", output_dir,
                        "--wandb_name", wandb_name,
                        "--num_iter", str(num_iter),
                        "--save_interval", str(save_interval),
                        "--seed", str(seed),
                        "--use_gpu", str(int(use_gpu)),
                        "--fix_scale", str(args.fix_scale),
                        "--mask_object", str(args.mask_object),
                        "--mask_object_attention", str(args.mask_object),
                        "--display_logs", str(int(args.colab)),
                        "--display", str(int(args.display))])
    if exit_code.returncode:
        sys.exit(1)

    config = np.load(f"{output_dir}/{wandb_name}/config.npy",
                     allow_pickle=True)[()]
    loss_eval = np.array(config['loss_eval'])
    inds = np.argsort(loss_eval)
    losses_all[wandb_name] = loss_eval[inds][0]


def display_(seed, wandb_name):
    path_to_svg = f"{output_dir}/{wandb_name}/svg_logs/"
    intervals_ = list(range(0, num_iter, save_interval))
    filename = f"svg_iter0.svg"
    display(IntSlider())
    out = Output()
    display(out)
    for i in intervals_:
        filename = f"svg_iter{i}.svg"
        not_exist = True
        while not_exist:
            not_exist = not os.path.isfile(f"{path_to_svg}/{filename}")
            continue
        with out:
            clear_output()
            print("")
            display(IntProgress(
                value=i,
                min=0,
                max=num_iter,
                description='Processing:',
                bar_style='info',  # 'success', 'info', 'warning', 'danger' or ''
                style={'bar_color': 'maroon'},
                orientation='horizontal'
            ))
            display(SVG(f"{path_to_svg}/svg_iter{i}.svg"))


if multiprocess:
    ncpus = 10
    P = mp.Pool(ncpus)  # Generate pool of workers

for seed in seeds:
    wandb_name = f"{test_name}_{args.num_strokes}strokes_seed{seed}"
    if multiprocess:
        P.apply_async(run, (seed, wandb_name))
    else:
        run(seed, wandb_name)

    if args.display:
        time.sleep(10)
        P.apply_async(display_, (0, f"{test_name}_{args.num_strokes}strokes_seed0"))

if multiprocess:
    P.close()
    P.join()  # start processes
sorted_final = dict(sorted(losses_all.items(), key=lambda item: item[1]))
copyfile(f"{output_dir}/{list(sorted_final.keys())[0]}/best_iter.svg",
         f"{output_dir}/{list(sorted_final.keys())[0]}_best.svg")
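Note: a typical invocation of the script above, sketched as a comment block (flags mirror the argparse definitions; camel.png is one of the bundled target images):

# Example invocation from the repo root (illustrative):
#   python run_object_sketching.py --target_file "camel.png" --num_strokes 16 --num_sketches 3
# Each seed writes its run to output_sketches/camel/camel_16strokes_seed<N>/, and the
# lowest-loss run's best_iter.svg is copied to output_sketches/camel/<run_name>_best.svg.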
sketch_utils.py
ADDED
@@ -0,0 +1,295 @@
import os

import imageio
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pydiffvg
import skimage
import skimage.io
import torch
import wandb
import PIL
from PIL import Image
from torchvision import transforms
from torchvision.utils import make_grid
from skimage.transform import resize

from U2Net_.model import U2NET


def imwrite(img, filename, gamma=2.2, normalize=False, use_wandb=False, wandb_name="", step=0, input_im=None):
    directory = os.path.dirname(filename)
    if directory != '' and not os.path.exists(directory):
        os.makedirs(directory)

    if not isinstance(img, np.ndarray):
        img = img.data.numpy()
    if normalize:
        img_rng = np.max(img) - np.min(img)
        if img_rng > 0:
            img = (img - np.min(img)) / img_rng
    img = np.clip(img, 0.0, 1.0)
    if img.ndim == 2:
        # repeat along the third dimension
        img = np.expand_dims(img, 2)
    img[:, :, :3] = np.power(img[:, :, :3], 1.0 / gamma)
    img = (img * 255).astype(np.uint8)

    skimage.io.imsave(filename, img, check_contrast=False)
    images = [wandb.Image(Image.fromarray(img), caption="output")]
    if input_im is not None and step == 0:
        images.append(wandb.Image(input_im, caption="input"))
    if use_wandb:
        wandb.log({wandb_name + "_": images}, step=step)


def plot_batch(inputs, outputs, output_dir, step, use_wandb, title):
    plt.figure()
    plt.subplot(2, 1, 1)
    grid = make_grid(inputs.clone().detach(), normalize=True, pad_value=2)
    npgrid = grid.cpu().numpy()
    plt.imshow(np.transpose(npgrid, (1, 2, 0)), interpolation='nearest')
    plt.axis("off")
    plt.title("inputs")

    plt.subplot(2, 1, 2)
    grid = make_grid(outputs, normalize=False, pad_value=2)
    npgrid = grid.detach().cpu().numpy()
    plt.imshow(np.transpose(npgrid, (1, 2, 0)), interpolation='nearest')
    plt.axis("off")
    plt.title("outputs")

    plt.tight_layout()
    if use_wandb:
        wandb.log({"output": wandb.Image(plt)}, step=step)
    plt.savefig("{}/{}".format(output_dir, title))
    plt.close()


def log_input(use_wandb, epoch, inputs, output_dir):
    grid = make_grid(inputs.clone().detach(), normalize=True, pad_value=2)
    npgrid = grid.cpu().numpy()
    plt.imshow(np.transpose(npgrid, (1, 2, 0)), interpolation='nearest')
    plt.axis("off")
    plt.tight_layout()
    if use_wandb:
        wandb.log({"input": wandb.Image(plt)}, step=epoch)
    plt.close()
    input_ = inputs[0].cpu().clone().detach().permute(1, 2, 0).numpy()
    input_ = (input_ - input_.min()) / (input_.max() - input_.min())
    input_ = (input_ * 255).astype(np.uint8)
    imageio.imwrite("{}/{}.png".format(output_dir, "input"), input_)


def log_sketch_summary_final(path_svg, use_wandb, device, epoch, loss, title):
    canvas_width, canvas_height, shapes, shape_groups = load_svg(path_svg)
    _render = pydiffvg.RenderFunction.apply
    scene_args = pydiffvg.RenderFunction.serialize_scene(
        canvas_width, canvas_height, shapes, shape_groups)
    img = _render(canvas_width,   # width
                  canvas_height,  # height
                  2,              # num_samples_x
                  2,              # num_samples_y
                  0,              # seed
                  None,
                  *scene_args)

    img = img[:, :, 3:4] * img[:, :, :3] + \
        torch.ones(img.shape[0], img.shape[1], 3,
                   device=device) * (1 - img[:, :, 3:4])
    img = img[:, :, :3]
    plt.imshow(img.cpu().numpy())
    plt.axis("off")
    plt.title(f"{title} best res [{epoch}] [{loss}.]")
    if use_wandb:
        wandb.log({title: wandb.Image(plt)})
    plt.close()


def log_sketch_summary(sketch, title, use_wandb):
    plt.figure()
    grid = make_grid(sketch.clone().detach(), normalize=True, pad_value=2)
    npgrid = grid.cpu().numpy()
    plt.imshow(np.transpose(npgrid, (1, 2, 0)), interpolation='nearest')
    plt.axis("off")
    plt.title(title)
    plt.tight_layout()
    if use_wandb:
        wandb.run.summary["best_loss_im"] = wandb.Image(plt)
    plt.close()


def load_svg(path_svg):
    svg = os.path.join(path_svg)
    canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(
        svg)
    return canvas_width, canvas_height, shapes, shape_groups


def read_svg(path_svg, device, multiply=False):
    canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(
        path_svg)
    if multiply:
        canvas_width *= 2
        canvas_height *= 2
        for path in shapes:
            path.points *= 2
            path.stroke_width *= 2
    _render = pydiffvg.RenderFunction.apply
    scene_args = pydiffvg.RenderFunction.serialize_scene(
        canvas_width, canvas_height, shapes, shape_groups)
    img = _render(canvas_width,   # width
                  canvas_height,  # height
                  2,              # num_samples_x
                  2,              # num_samples_y
                  0,              # seed
                  None,
                  *scene_args)
    img = img[:, :, 3:4] * img[:, :, :3] + \
        torch.ones(img.shape[0], img.shape[1], 3,
                   device=device) * (1 - img[:, :, 3:4])
    img = img[:, :, :3]
    return img


def plot_attn_dino(attn, threshold_map, inputs, inds, use_wandb, output_path):
    # currently supports one image (and not a batch)
    plt.figure(figsize=(10, 5))

    plt.subplot(2, attn.shape[0] + 2, 1)
    main_im = make_grid(inputs, normalize=True, pad_value=2)
    main_im = np.transpose(main_im.cpu().numpy(), (1, 2, 0))
    plt.imshow(main_im, interpolation='nearest')
    plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
    plt.title("input im")
    plt.axis("off")

    plt.subplot(2, attn.shape[0] + 2, 2)
    plt.imshow(attn.sum(0).numpy(), interpolation='nearest')
    plt.title("attn map sum")
    plt.axis("off")

    plt.subplot(2, attn.shape[0] + 2, attn.shape[0] + 3)
    plt.imshow(threshold_map[-1].numpy(), interpolation='nearest')
    plt.title("prob sum")
    plt.axis("off")

    plt.subplot(2, attn.shape[0] + 2, attn.shape[0] + 4)
    plt.imshow(threshold_map[:-1].sum(0).numpy(), interpolation='nearest')
    plt.title("thresh sum")
    plt.axis("off")

    for i in range(attn.shape[0]):
        plt.subplot(2, attn.shape[0] + 2, i + 3)
        plt.imshow(attn[i].numpy())
        plt.axis("off")
        plt.subplot(2, attn.shape[0] + 2, attn.shape[0] + 1 + i + 4)
        plt.imshow(threshold_map[i].numpy())
        plt.axis("off")
    plt.tight_layout()
    if use_wandb:
        wandb.log({"attention_map": wandb.Image(plt)})
    plt.savefig(output_path)
    plt.close()


def plot_attn_clip(attn, threshold_map, inputs, inds, use_wandb, output_path, display_logs):
    # currently supports one image (and not a batch)
    plt.figure(figsize=(10, 5))

    plt.subplot(1, 3, 1)
    main_im = make_grid(inputs, normalize=True, pad_value=2)
    main_im = np.transpose(main_im.cpu().numpy(), (1, 2, 0))
    plt.imshow(main_im, interpolation='nearest')
    plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
    plt.title("input im")
    plt.axis("off")

    plt.subplot(1, 3, 2)
    plt.imshow(attn, interpolation='nearest', vmin=0, vmax=1)
    plt.title("attn map")
    plt.axis("off")

    plt.subplot(1, 3, 3)
    threshold_map_ = (threshold_map - threshold_map.min()) / \
        (threshold_map.max() - threshold_map.min())
    plt.imshow(threshold_map_, interpolation='nearest', vmin=0, vmax=1)
    plt.title("prob softmax")
    plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
    plt.axis("off")

    plt.tight_layout()
    if use_wandb:
        wandb.log({"attention_map": wandb.Image(plt)})
    plt.savefig(output_path)
    plt.close()


def plot_atten(attn, threshold_map, inputs, inds, use_wandb, output_path, saliency_model, display_logs):
    if saliency_model == "dino":
        plot_attn_dino(attn, threshold_map, inputs,
                       inds, use_wandb, output_path)
    elif saliency_model == "clip":
        plot_attn_clip(attn, threshold_map, inputs, inds,
                       use_wandb, output_path, display_logs)


def fix_image_scale(im):
    im_np = np.array(im) / 255
    height, width = im_np.shape[0], im_np.shape[1]
    max_len = max(height, width) + 20
    new_background = np.ones((max_len, max_len, 3))
    y, x = max_len // 2 - height // 2, max_len // 2 - width // 2
    new_background[y: y + height, x: x + width] = im_np
    new_background = (new_background / new_background.max()
                      * 255).astype(np.uint8)
    new_im = Image.fromarray(new_background)
    return new_im


def get_mask_u2net(args, pil_im):
    w, h = pil_im.size[0], pil_im.size[1]
    im_size = min(w, h)
    data_transforms = transforms.Compose([
        transforms.Resize(min(320, im_size), interpolation=PIL.Image.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(
            0.26862954, 0.26130258, 0.27577711)),
    ])

    input_im_trans = data_transforms(pil_im).unsqueeze(0).to(args.device)

    model_dir = os.path.join("./U2Net_/saved_models/u2net.pth")
    net = U2NET(3, 1)
    if torch.cuda.is_available() and args.use_gpu:
        net.load_state_dict(torch.load(model_dir))
        net.to(args.device)
    else:
        net.load_state_dict(torch.load(model_dir, map_location='cpu'))
    net.eval()
    with torch.no_grad():
        d1, d2, d3, d4, d5, d6, d7 = net(input_im_trans.detach())
    pred = d1[:, 0, :, :]
    pred = (pred - pred.min()) / (pred.max() - pred.min())
    predict = pred
    predict[predict < 0.5] = 0
    predict[predict >= 0.5] = 1
    mask = torch.cat([predict, predict, predict], axis=0).permute(1, 2, 0)
    mask = mask.cpu().numpy()
    mask = resize(mask, (h, w), anti_aliasing=False)
    mask[mask < 0.5] = 0
    mask[mask >= 0.5] = 1

    # predict_np = predict.clone().cpu().data.numpy()
    im = Image.fromarray((mask[:, :, 0] * 255).astype(np.uint8)).convert('RGB')
    im.save(f"{args.output_dir}/mask.png")

    im_np = np.array(pil_im)
    im_np = im_np / im_np.max()
    im_np = mask * im_np
    im_np[mask == 0] = 1
    im_final = (im_np / im_np.max() * 255).astype(np.uint8)
    im_final = Image.fromarray(im_final)

    return im_final, predict
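Note: the background removal at the end of get_mask_u2net() multiplies the image by the binarized U2-Net mask and paints the zeroed region white. A toy numpy illustration of that step (illustration only, not repo code):

# Mask the object and whiten the background (illustration only).
import numpy as np

im = np.array([[0.2, 0.8],
               [0.6, 0.4]])      # toy 2x2 single-channel image in [0, 1]
mask = np.array([[1.0, 0.0],
                 [1.0, 0.0]])    # 1 = object, 0 = background
out = mask * im                  # keep object pixels
out[mask == 0] = 1               # paint background white
# out == [[0.2, 1.0],
#         [0.6, 1.0]]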
target_images/camel.png
ADDED
target_images/flamingo.png
ADDED
target_images/horse.png
ADDED
target_images/rose.jpeg
ADDED