yael-vinker committed on
Commit 3c149ed
1 Parent(s): 253b0de
CLIPasso-local-demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
CLIPasso.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
LICENSE ADDED
@@ -0,0 +1,437 @@
1
+ Attribution-NonCommercial-ShareAlike 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+
55
+ =======================================================================
56
+
57
+ Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
58
+ Public License
59
+
60
+ By exercising the Licensed Rights (defined below), You accept and agree
61
+ to be bound by the terms and conditions of this Creative Commons
62
+ Attribution-NonCommercial-ShareAlike 4.0 International Public License
63
+ ("Public License"). To the extent this Public License may be
64
+ interpreted as a contract, You are granted the Licensed Rights in
65
+ consideration of Your acceptance of these terms and conditions, and the
66
+ Licensor grants You such rights in consideration of benefits the
67
+ Licensor receives from making the Licensed Material available under
68
+ these terms and conditions.
69
+
70
+
71
+ Section 1 -- Definitions.
72
+
73
+ a. Adapted Material means material subject to Copyright and Similar
74
+ Rights that is derived from or based upon the Licensed Material
75
+ and in which the Licensed Material is translated, altered,
76
+ arranged, transformed, or otherwise modified in a manner requiring
77
+ permission under the Copyright and Similar Rights held by the
78
+ Licensor. For purposes of this Public License, where the Licensed
79
+ Material is a musical work, performance, or sound recording,
80
+ Adapted Material is always produced where the Licensed Material is
81
+ synched in timed relation with a moving image.
82
+
83
+ b. Adapter's License means the license You apply to Your Copyright
84
+ and Similar Rights in Your contributions to Adapted Material in
85
+ accordance with the terms and conditions of this Public License.
86
+
87
+ c. BY-NC-SA Compatible License means a license listed at
88
+ creativecommons.org/compatiblelicenses, approved by Creative
89
+ Commons as essentially the equivalent of this Public License.
90
+
91
+ d. Copyright and Similar Rights means copyright and/or similar rights
92
+ closely related to copyright including, without limitation,
93
+ performance, broadcast, sound recording, and Sui Generis Database
94
+ Rights, without regard to how the rights are labeled or
95
+ categorized. For purposes of this Public License, the rights
96
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
97
+ Rights.
98
+
99
+ e. Effective Technological Measures means those measures that, in the
100
+ absence of proper authority, may not be circumvented under laws
101
+ fulfilling obligations under Article 11 of the WIPO Copyright
102
+ Treaty adopted on December 20, 1996, and/or similar international
103
+ agreements.
104
+
105
+ f. Exceptions and Limitations means fair use, fair dealing, and/or
106
+ any other exception or limitation to Copyright and Similar Rights
107
+ that applies to Your use of the Licensed Material.
108
+
109
+ g. License Elements means the license attributes listed in the name
110
+ of a Creative Commons Public License. The License Elements of this
111
+ Public License are Attribution, NonCommercial, and ShareAlike.
112
+
113
+ h. Licensed Material means the artistic or literary work, database,
114
+ or other material to which the Licensor applied this Public
115
+ License.
116
+
117
+ i. Licensed Rights means the rights granted to You subject to the
118
+ terms and conditions of this Public License, which are limited to
119
+ all Copyright and Similar Rights that apply to Your use of the
120
+ Licensed Material and that the Licensor has authority to license.
121
+
122
+ j. Licensor means the individual(s) or entity(ies) granting rights
123
+ under this Public License.
124
+
125
+ k. NonCommercial means not primarily intended for or directed towards
126
+ commercial advantage or monetary compensation. For purposes of
127
+ this Public License, the exchange of the Licensed Material for
128
+ other material subject to Copyright and Similar Rights by digital
129
+ file-sharing or similar means is NonCommercial provided there is
130
+ no payment of monetary compensation in connection with the
131
+ exchange.
132
+
133
+ l. Share means to provide material to the public by any means or
134
+ process that requires permission under the Licensed Rights, such
135
+ as reproduction, public display, public performance, distribution,
136
+ dissemination, communication, or importation, and to make material
137
+ available to the public including in ways that members of the
138
+ public may access the material from a place and at a time
139
+ individually chosen by them.
140
+
141
+ m. Sui Generis Database Rights means rights other than copyright
142
+ resulting from Directive 96/9/EC of the European Parliament and of
143
+ the Council of 11 March 1996 on the legal protection of databases,
144
+ as amended and/or succeeded, as well as other essentially
145
+ equivalent rights anywhere in the world.
146
+
147
+ n. You means the individual or entity exercising the Licensed Rights
148
+ under this Public License. Your has a corresponding meaning.
149
+
150
+
151
+ Section 2 -- Scope.
152
+
153
+ a. License grant.
154
+
155
+ 1. Subject to the terms and conditions of this Public License,
156
+ the Licensor hereby grants You a worldwide, royalty-free,
157
+ non-sublicensable, non-exclusive, irrevocable license to
158
+ exercise the Licensed Rights in the Licensed Material to:
159
+
160
+ a. reproduce and Share the Licensed Material, in whole or
161
+ in part, for NonCommercial purposes only; and
162
+
163
+ b. produce, reproduce, and Share Adapted Material for
164
+ NonCommercial purposes only.
165
+
166
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
167
+ Exceptions and Limitations apply to Your use, this Public
168
+ License does not apply, and You do not need to comply with
169
+ its terms and conditions.
170
+
171
+ 3. Term. The term of this Public License is specified in Section
172
+ 6(a).
173
+
174
+ 4. Media and formats; technical modifications allowed. The
175
+ Licensor authorizes You to exercise the Licensed Rights in
176
+ all media and formats whether now known or hereafter created,
177
+ and to make technical modifications necessary to do so. The
178
+ Licensor waives and/or agrees not to assert any right or
179
+ authority to forbid You from making technical modifications
180
+ necessary to exercise the Licensed Rights, including
181
+ technical modifications necessary to circumvent Effective
182
+ Technological Measures. For purposes of this Public License,
183
+ simply making modifications authorized by this Section 2(a)
184
+ (4) never produces Adapted Material.
185
+
186
+ 5. Downstream recipients.
187
+
188
+ a. Offer from the Licensor -- Licensed Material. Every
189
+ recipient of the Licensed Material automatically
190
+ receives an offer from the Licensor to exercise the
191
+ Licensed Rights under the terms and conditions of this
192
+ Public License.
193
+
194
+ b. Additional offer from the Licensor -- Adapted Material.
195
+ Every recipient of Adapted Material from You
196
+ automatically receives an offer from the Licensor to
197
+ exercise the Licensed Rights in the Adapted Material
198
+ under the conditions of the Adapter's License You apply.
199
+
200
+ c. No downstream restrictions. You may not offer or impose
201
+ any additional or different terms or conditions on, or
202
+ apply any Effective Technological Measures to, the
203
+ Licensed Material if doing so restricts exercise of the
204
+ Licensed Rights by any recipient of the Licensed
205
+ Material.
206
+
207
+ 6. No endorsement. Nothing in this Public License constitutes or
208
+ may be construed as permission to assert or imply that You
209
+ are, or that Your use of the Licensed Material is, connected
210
+ with, or sponsored, endorsed, or granted official status by,
211
+ the Licensor or others designated to receive attribution as
212
+ provided in Section 3(a)(1)(A)(i).
213
+
214
+ b. Other rights.
215
+
216
+ 1. Moral rights, such as the right of integrity, are not
217
+ licensed under this Public License, nor are publicity,
218
+ privacy, and/or other similar personality rights; however, to
219
+ the extent possible, the Licensor waives and/or agrees not to
220
+ assert any such rights held by the Licensor to the limited
221
+ extent necessary to allow You to exercise the Licensed
222
+ Rights, but not otherwise.
223
+
224
+ 2. Patent and trademark rights are not licensed under this
225
+ Public License.
226
+
227
+ 3. To the extent possible, the Licensor waives any right to
228
+ collect royalties from You for the exercise of the Licensed
229
+ Rights, whether directly or through a collecting society
230
+ under any voluntary or waivable statutory or compulsory
231
+ licensing scheme. In all other cases the Licensor expressly
232
+ reserves any right to collect such royalties, including when
233
+ the Licensed Material is used other than for NonCommercial
234
+ purposes.
235
+
236
+
237
+ Section 3 -- License Conditions.
238
+
239
+ Your exercise of the Licensed Rights is expressly made subject to the
240
+ following conditions.
241
+
242
+ a. Attribution.
243
+
244
+ 1. If You Share the Licensed Material (including in modified
245
+ form), You must:
246
+
247
+ a. retain the following if it is supplied by the Licensor
248
+ with the Licensed Material:
249
+
250
+ i. identification of the creator(s) of the Licensed
251
+ Material and any others designated to receive
252
+ attribution, in any reasonable manner requested by
253
+ the Licensor (including by pseudonym if
254
+ designated);
255
+
256
+ ii. a copyright notice;
257
+
258
+ iii. a notice that refers to this Public License;
259
+
260
+ iv. a notice that refers to the disclaimer of
261
+ warranties;
262
+
263
+ v. a URI or hyperlink to the Licensed Material to the
264
+ extent reasonably practicable;
265
+
266
+ b. indicate if You modified the Licensed Material and
267
+ retain an indication of any previous modifications; and
268
+
269
+ c. indicate the Licensed Material is licensed under this
270
+ Public License, and include the text of, or the URI or
271
+ hyperlink to, this Public License.
272
+
273
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
274
+ reasonable manner based on the medium, means, and context in
275
+ which You Share the Licensed Material. For example, it may be
276
+ reasonable to satisfy the conditions by providing a URI or
277
+ hyperlink to a resource that includes the required
278
+ information.
279
+ 3. If requested by the Licensor, You must remove any of the
280
+ information required by Section 3(a)(1)(A) to the extent
281
+ reasonably practicable.
282
+
283
+ b. ShareAlike.
284
+
285
+ In addition to the conditions in Section 3(a), if You Share
286
+ Adapted Material You produce, the following conditions also apply.
287
+
288
+ 1. The Adapter's License You apply must be a Creative Commons
289
+ license with the same License Elements, this version or
290
+ later, or a BY-NC-SA Compatible License.
291
+
292
+ 2. You must include the text of, or the URI or hyperlink to, the
293
+ Adapter's License You apply. You may satisfy this condition
294
+ in any reasonable manner based on the medium, means, and
295
+ context in which You Share Adapted Material.
296
+
297
+ 3. You may not offer or impose any additional or different terms
298
+ or conditions on, or apply any Effective Technological
299
+ Measures to, Adapted Material that restrict exercise of the
300
+ rights granted under the Adapter's License You apply.
301
+
302
+
303
+ Section 4 -- Sui Generis Database Rights.
304
+
305
+ Where the Licensed Rights include Sui Generis Database Rights that
306
+ apply to Your use of the Licensed Material:
307
+
308
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
309
+ to extract, reuse, reproduce, and Share all or a substantial
310
+ portion of the contents of the database for NonCommercial purposes
311
+ only;
312
+
313
+ b. if You include all or a substantial portion of the database
314
+ contents in a database in which You have Sui Generis Database
315
+ Rights, then the database in which You have Sui Generis Database
316
+ Rights (but not its individual contents) is Adapted Material,
317
+ including for purposes of Section 3(b); and
318
+
319
+ c. You must comply with the conditions in Section 3(a) if You Share
320
+ all or a substantial portion of the contents of the database.
321
+
322
+ For the avoidance of doubt, this Section 4 supplements and does not
323
+ replace Your obligations under this Public License where the Licensed
324
+ Rights include other Copyright and Similar Rights.
325
+
326
+
327
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
328
+
329
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
330
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
331
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
332
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
333
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
334
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
335
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
336
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
337
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
338
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
339
+
340
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
341
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
342
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
343
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
344
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
345
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
346
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
347
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
348
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
349
+
350
+ c. The disclaimer of warranties and limitation of liability provided
351
+ above shall be interpreted in a manner that, to the extent
352
+ possible, most closely approximates an absolute disclaimer and
353
+ waiver of all liability.
354
+
355
+
356
+ Section 6 -- Term and Termination.
357
+
358
+ a. This Public License applies for the term of the Copyright and
359
+ Similar Rights licensed here. However, if You fail to comply with
360
+ this Public License, then Your rights under this Public License
361
+ terminate automatically.
362
+
363
+ b. Where Your right to use the Licensed Material has terminated under
364
+ Section 6(a), it reinstates:
365
+
366
+ 1. automatically as of the date the violation is cured, provided
367
+ it is cured within 30 days of Your discovery of the
368
+ violation; or
369
+
370
+ 2. upon express reinstatement by the Licensor.
371
+
372
+ For the avoidance of doubt, this Section 6(b) does not affect any
373
+ right the Licensor may have to seek remedies for Your violations
374
+ of this Public License.
375
+
376
+ c. For the avoidance of doubt, the Licensor may also offer the
377
+ Licensed Material under separate terms or conditions or stop
378
+ distributing the Licensed Material at any time; however, doing so
379
+ will not terminate this Public License.
380
+
381
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
382
+ License.
383
+
384
+
385
+ Section 7 -- Other Terms and Conditions.
386
+
387
+ a. The Licensor shall not be bound by any additional or different
388
+ terms or conditions communicated by You unless expressly agreed.
389
+
390
+ b. Any arrangements, understandings, or agreements regarding the
391
+ Licensed Material not stated herein are separate from and
392
+ independent of the terms and conditions of this Public License.
393
+
394
+
395
+ Section 8 -- Interpretation.
396
+
397
+ a. For the avoidance of doubt, this Public License does not, and
398
+ shall not be interpreted to, reduce, limit, restrict, or impose
399
+ conditions on any use of the Licensed Material that could lawfully
400
+ be made without permission under this Public License.
401
+
402
+ b. To the extent possible, if any provision of this Public License is
403
+ deemed unenforceable, it shall be automatically reformed to the
404
+ minimum extent necessary to make it enforceable. If the provision
405
+ cannot be reformed, it shall be severed from this Public License
406
+ without affecting the enforceability of the remaining terms and
407
+ conditions.
408
+
409
+ c. No term or condition of this Public License will be waived and no
410
+ failure to comply consented to unless expressly agreed to by the
411
+ Licensor.
412
+
413
+ d. Nothing in this Public License constitutes or may be interpreted
414
+ as a limitation upon, or waiver of, any privileges and immunities
415
+ that apply to the Licensor or You, including from the legal
416
+ processes of any jurisdiction or authority.
417
+
418
+ =======================================================================
419
+
420
+ Creative Commons is not a party to its public
421
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
422
+ its public licenses to material it publishes and in those instances
423
+ will be considered the “Licensor.” The text of the Creative Commons
424
+ public licenses is dedicated to the public domain under the CC0 Public
425
+ Domain Dedication. Except for the limited purpose of indicating that
426
+ material is shared under a Creative Commons public license or as
427
+ otherwise permitted by the Creative Commons policies published at
428
+ creativecommons.org/policies, Creative Commons does not authorize the
429
+ use of the trademark "Creative Commons" or any other trademark or logo
430
+ of Creative Commons without its prior written consent including,
431
+ without limitation, in connection with any unauthorized modifications
432
+ to any of its public licenses or any other arrangements,
433
+ understandings, or agreements concerning use of licensed material. For
434
+ the avoidance of doubt, this paragraph does not form part of the
435
+ public licenses.
436
+
437
+ Creative Commons may be contacted at creativecommons.org.
U2Net_/model/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .u2net import U2NET
2
+ from .u2net import U2NETP
U2Net_/model/u2net.py ADDED
@@ -0,0 +1,525 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class REBNCONV(nn.Module):
6
+ def __init__(self,in_ch=3,out_ch=3,dirate=1):
7
+ super(REBNCONV,self).__init__()
8
+
9
+ self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)
10
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
11
+ self.relu_s1 = nn.ReLU(inplace=True)
12
+
13
+ def forward(self,x):
14
+
15
+ hx = x
16
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
17
+
18
+ return xout
19
+
20
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
21
+ def _upsample_like(src,tar):
22
+
23
+ src = F.upsample(src,size=tar.shape[2:],mode='bilinear')
24
+
25
+ return src
26
+
27
+
28
+ ### RSU-7 ###
29
+ class RSU7(nn.Module):#UNet07DRES(nn.Module):
30
+
31
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
32
+ super(RSU7,self).__init__()
33
+
34
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
35
+
36
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
37
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
38
+
39
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
40
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
41
+
42
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
43
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
44
+
45
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
46
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
47
+
48
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
49
+ self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
50
+
51
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
52
+
53
+ self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
54
+
55
+ self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
56
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
57
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
58
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
59
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
60
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
61
+
62
+ def forward(self,x):
63
+
64
+ hx = x
65
+ hxin = self.rebnconvin(hx)
66
+
67
+ hx1 = self.rebnconv1(hxin)
68
+ hx = self.pool1(hx1)
69
+
70
+ hx2 = self.rebnconv2(hx)
71
+ hx = self.pool2(hx2)
72
+
73
+ hx3 = self.rebnconv3(hx)
74
+ hx = self.pool3(hx3)
75
+
76
+ hx4 = self.rebnconv4(hx)
77
+ hx = self.pool4(hx4)
78
+
79
+ hx5 = self.rebnconv5(hx)
80
+ hx = self.pool5(hx5)
81
+
82
+ hx6 = self.rebnconv6(hx)
83
+
84
+ hx7 = self.rebnconv7(hx6)
85
+
86
+ hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
87
+ hx6dup = _upsample_like(hx6d,hx5)
88
+
89
+ hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
90
+ hx5dup = _upsample_like(hx5d,hx4)
91
+
92
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
93
+ hx4dup = _upsample_like(hx4d,hx3)
94
+
95
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
96
+ hx3dup = _upsample_like(hx3d,hx2)
97
+
98
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
99
+ hx2dup = _upsample_like(hx2d,hx1)
100
+
101
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
102
+
103
+ return hx1d + hxin
104
+
105
+ ### RSU-6 ###
106
+ class RSU6(nn.Module):#UNet06DRES(nn.Module):
107
+
108
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
109
+ super(RSU6,self).__init__()
110
+
111
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
112
+
113
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
114
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
115
+
116
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
117
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
118
+
119
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
120
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
121
+
122
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
123
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
124
+
125
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
126
+
127
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
128
+
129
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
130
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
131
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
132
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
133
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
134
+
135
+ def forward(self,x):
136
+
137
+ hx = x
138
+
139
+ hxin = self.rebnconvin(hx)
140
+
141
+ hx1 = self.rebnconv1(hxin)
142
+ hx = self.pool1(hx1)
143
+
144
+ hx2 = self.rebnconv2(hx)
145
+ hx = self.pool2(hx2)
146
+
147
+ hx3 = self.rebnconv3(hx)
148
+ hx = self.pool3(hx3)
149
+
150
+ hx4 = self.rebnconv4(hx)
151
+ hx = self.pool4(hx4)
152
+
153
+ hx5 = self.rebnconv5(hx)
154
+
155
+ hx6 = self.rebnconv6(hx5)
156
+
157
+
158
+ hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
159
+ hx5dup = _upsample_like(hx5d,hx4)
160
+
161
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
162
+ hx4dup = _upsample_like(hx4d,hx3)
163
+
164
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
165
+ hx3dup = _upsample_like(hx3d,hx2)
166
+
167
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
168
+ hx2dup = _upsample_like(hx2d,hx1)
169
+
170
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
171
+
172
+ return hx1d + hxin
173
+
174
+ ### RSU-5 ###
175
+ class RSU5(nn.Module):#UNet05DRES(nn.Module):
176
+
177
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
178
+ super(RSU5,self).__init__()
179
+
180
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
181
+
182
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
183
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
184
+
185
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
186
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
187
+
188
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
189
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
190
+
191
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
192
+
193
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
194
+
195
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
196
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
197
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
198
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
199
+
200
+ def forward(self,x):
201
+
202
+ hx = x
203
+
204
+ hxin = self.rebnconvin(hx)
205
+
206
+ hx1 = self.rebnconv1(hxin)
207
+ hx = self.pool1(hx1)
208
+
209
+ hx2 = self.rebnconv2(hx)
210
+ hx = self.pool2(hx2)
211
+
212
+ hx3 = self.rebnconv3(hx)
213
+ hx = self.pool3(hx3)
214
+
215
+ hx4 = self.rebnconv4(hx)
216
+
217
+ hx5 = self.rebnconv5(hx4)
218
+
219
+ hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
220
+ hx4dup = _upsample_like(hx4d,hx3)
221
+
222
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
223
+ hx3dup = _upsample_like(hx3d,hx2)
224
+
225
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
226
+ hx2dup = _upsample_like(hx2d,hx1)
227
+
228
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
229
+
230
+ return hx1d + hxin
231
+
232
+ ### RSU-4 ###
233
+ class RSU4(nn.Module):#UNet04DRES(nn.Module):
234
+
235
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
236
+ super(RSU4,self).__init__()
237
+
238
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
239
+
240
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
241
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
242
+
243
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
244
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
245
+
246
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
247
+
248
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
249
+
250
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
251
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
252
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
253
+
254
+ def forward(self,x):
255
+
256
+ hx = x
257
+
258
+ hxin = self.rebnconvin(hx)
259
+
260
+ hx1 = self.rebnconv1(hxin)
261
+ hx = self.pool1(hx1)
262
+
263
+ hx2 = self.rebnconv2(hx)
264
+ hx = self.pool2(hx2)
265
+
266
+ hx3 = self.rebnconv3(hx)
267
+
268
+ hx4 = self.rebnconv4(hx3)
269
+
270
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
271
+ hx3dup = _upsample_like(hx3d,hx2)
272
+
273
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
274
+ hx2dup = _upsample_like(hx2d,hx1)
275
+
276
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
277
+
278
+ return hx1d + hxin
279
+
280
+ ### RSU-4F ###
281
+ class RSU4F(nn.Module):#UNet04FRES(nn.Module):
282
+
283
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
284
+ super(RSU4F,self).__init__()
285
+
286
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
287
+
288
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
289
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
290
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
291
+
292
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
293
+
294
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
295
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
296
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
297
+
298
+ def forward(self,x):
299
+
300
+ hx = x
301
+
302
+ hxin = self.rebnconvin(hx)
303
+
304
+ hx1 = self.rebnconv1(hxin)
305
+ hx2 = self.rebnconv2(hx1)
306
+ hx3 = self.rebnconv3(hx2)
307
+
308
+ hx4 = self.rebnconv4(hx3)
309
+
310
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
311
+ hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
312
+ hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
313
+
314
+ return hx1d + hxin
315
+
316
+
317
+ ##### U^2-Net ####
318
+ class U2NET(nn.Module):
319
+
320
+ def __init__(self,in_ch=3,out_ch=1):
321
+ super(U2NET,self).__init__()
322
+
323
+ self.stage1 = RSU7(in_ch,32,64)
324
+ self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
325
+
326
+ self.stage2 = RSU6(64,32,128)
327
+ self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
328
+
329
+ self.stage3 = RSU5(128,64,256)
330
+ self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
331
+
332
+ self.stage4 = RSU4(256,128,512)
333
+ self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
334
+
335
+ self.stage5 = RSU4F(512,256,512)
336
+ self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
337
+
338
+ self.stage6 = RSU4F(512,256,512)
339
+
340
+ # decoder
341
+ self.stage5d = RSU4F(1024,256,512)
342
+ self.stage4d = RSU4(1024,128,256)
343
+ self.stage3d = RSU5(512,64,128)
344
+ self.stage2d = RSU6(256,32,64)
345
+ self.stage1d = RSU7(128,16,64)
346
+
347
+ self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
348
+ self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
349
+ self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
350
+ self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
351
+ self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
352
+ self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
353
+
354
+ self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
355
+
356
+ def forward(self,x):
357
+
358
+ hx = x
359
+
360
+ #stage 1
361
+ hx1 = self.stage1(hx)
362
+ hx = self.pool12(hx1)
363
+
364
+ #stage 2
365
+ hx2 = self.stage2(hx)
366
+ hx = self.pool23(hx2)
367
+
368
+ #stage 3
369
+ hx3 = self.stage3(hx)
370
+ hx = self.pool34(hx3)
371
+
372
+ #stage 4
373
+ hx4 = self.stage4(hx)
374
+ hx = self.pool45(hx4)
375
+
376
+ #stage 5
377
+ hx5 = self.stage5(hx)
378
+ hx = self.pool56(hx5)
379
+
380
+ #stage 6
381
+ hx6 = self.stage6(hx)
382
+ hx6up = _upsample_like(hx6,hx5)
383
+
384
+ #-------------------- decoder --------------------
385
+ hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
386
+ hx5dup = _upsample_like(hx5d,hx4)
387
+
388
+ hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
389
+ hx4dup = _upsample_like(hx4d,hx3)
390
+
391
+ hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
392
+ hx3dup = _upsample_like(hx3d,hx2)
393
+
394
+ hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
395
+ hx2dup = _upsample_like(hx2d,hx1)
396
+
397
+ hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
398
+
399
+
400
+ #side output
401
+ d1 = self.side1(hx1d)
402
+
403
+ d2 = self.side2(hx2d)
404
+ d2 = _upsample_like(d2,d1)
405
+
406
+ d3 = self.side3(hx3d)
407
+ d3 = _upsample_like(d3,d1)
408
+
409
+ d4 = self.side4(hx4d)
410
+ d4 = _upsample_like(d4,d1)
411
+
412
+ d5 = self.side5(hx5d)
413
+ d5 = _upsample_like(d5,d1)
414
+
415
+ d6 = self.side6(hx6)
416
+ d6 = _upsample_like(d6,d1)
417
+
418
+ d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
419
+
420
+ return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)
421
+
422
+ ### U^2-Net small ###
423
+ class U2NETP(nn.Module):
424
+
425
+ def __init__(self,in_ch=3,out_ch=1):
426
+ super(U2NETP,self).__init__()
427
+
428
+ self.stage1 = RSU7(in_ch,16,64)
429
+ self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
430
+
431
+ self.stage2 = RSU6(64,16,64)
432
+ self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
433
+
434
+ self.stage3 = RSU5(64,16,64)
435
+ self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
436
+
437
+ self.stage4 = RSU4(64,16,64)
438
+ self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
439
+
440
+ self.stage5 = RSU4F(64,16,64)
441
+ self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
442
+
443
+ self.stage6 = RSU4F(64,16,64)
444
+
445
+ # decoder
446
+ self.stage5d = RSU4F(128,16,64)
447
+ self.stage4d = RSU4(128,16,64)
448
+ self.stage3d = RSU5(128,16,64)
449
+ self.stage2d = RSU6(128,16,64)
450
+ self.stage1d = RSU7(128,16,64)
451
+
452
+ self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
453
+ self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
454
+ self.side3 = nn.Conv2d(64,out_ch,3,padding=1)
455
+ self.side4 = nn.Conv2d(64,out_ch,3,padding=1)
456
+ self.side5 = nn.Conv2d(64,out_ch,3,padding=1)
457
+ self.side6 = nn.Conv2d(64,out_ch,3,padding=1)
458
+
459
+ self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
460
+
461
+ def forward(self,x):
462
+
463
+ hx = x
464
+
465
+ #stage 1
466
+ hx1 = self.stage1(hx)
467
+ hx = self.pool12(hx1)
468
+
469
+ #stage 2
470
+ hx2 = self.stage2(hx)
471
+ hx = self.pool23(hx2)
472
+
473
+ #stage 3
474
+ hx3 = self.stage3(hx)
475
+ hx = self.pool34(hx3)
476
+
477
+ #stage 4
478
+ hx4 = self.stage4(hx)
479
+ hx = self.pool45(hx4)
480
+
481
+ #stage 5
482
+ hx5 = self.stage5(hx)
483
+ hx = self.pool56(hx5)
484
+
485
+ #stage 6
486
+ hx6 = self.stage6(hx)
487
+ hx6up = _upsample_like(hx6,hx5)
488
+
489
+ #decoder
490
+ hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
491
+ hx5dup = _upsample_like(hx5d,hx4)
492
+
493
+ hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
494
+ hx4dup = _upsample_like(hx4d,hx3)
495
+
496
+ hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
497
+ hx3dup = _upsample_like(hx3d,hx2)
498
+
499
+ hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
500
+ hx2dup = _upsample_like(hx2d,hx1)
501
+
502
+ hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
503
+
504
+
505
+ #side output
506
+ d1 = self.side1(hx1d)
507
+
508
+ d2 = self.side2(hx2d)
509
+ d2 = _upsample_like(d2,d1)
510
+
511
+ d3 = self.side3(hx3d)
512
+ d3 = _upsample_like(d3,d1)
513
+
514
+ d4 = self.side4(hx4d)
515
+ d4 = _upsample_like(d4,d1)
516
+
517
+ d5 = self.side5(hx5d)
518
+ d5 = _upsample_like(d5,d1)
519
+
520
+ d6 = self.side6(hx6)
521
+ d6 = _upsample_like(d6,d1)
522
+
523
+ d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
524
+
525
+ return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)
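
A minimal usage sketch for the network defined above, assuming the repository root is on PYTHONPATH so that U2Net_/model/__init__.py resolves the imports, and that no pretrained weights are loaded (cog.yaml downloads the saliency checkpoint into U2Net_/saved_models/ separately). Note that F.upsample and F.sigmoid are deprecated aliases of F.interpolate and torch.sigmoid in recent PyTorch releases, so the forward pass emits deprecation warnings but still runs, e.g. under the torch==1.7.1 pinned in cog.yaml.

import torch

from U2Net_.model import U2NET, U2NETP  # re-exported by U2Net_/model/__init__.py

# The lite variant (U2NETP) and the full model (U2NET) share the same interface.
net = U2NETP(in_ch=3, out_ch=1)
net.eval()

# U^2-Net is commonly run on 320x320 RGB crops; other sizes also work because every
# decoder stage is upsampled back to the matching encoder resolution.
x = torch.randn(1, 3, 320, 320)

with torch.no_grad():
    d0, d1, d2, d3, d4, d5, d6 = net(x)

# d0 is the fused saliency map, d1..d6 are the side outputs, all passed through a sigmoid.
print(d0.shape)  # torch.Size([1, 1, 320, 320])
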
U2Net_/model/u2net_refactor.py ADDED
@@ -0,0 +1,168 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ import math
5
+
6
+ __all__ = ['U2NET_full', 'U2NET_lite']
7
+
8
+
9
+ def _upsample_like(x, size):
10
+ return nn.Upsample(size=size, mode='bilinear', align_corners=False)(x)
11
+
12
+
13
+ def _size_map(x, height):
14
+ # {height: size} for Upsample
15
+ size = list(x.shape[-2:])
16
+ sizes = {}
17
+ for h in range(1, height):
18
+ sizes[h] = size
19
+ size = [math.ceil(w / 2) for w in size]
20
+ return sizes
21
+
22
+
23
+ class REBNCONV(nn.Module):
24
+ def __init__(self, in_ch=3, out_ch=3, dilate=1):
25
+ super(REBNCONV, self).__init__()
26
+
27
+ self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dilate, dilation=1 * dilate)
28
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
29
+ self.relu_s1 = nn.ReLU(inplace=True)
30
+
31
+ def forward(self, x):
32
+ return self.relu_s1(self.bn_s1(self.conv_s1(x)))
33
+
34
+
35
+ class RSU(nn.Module):
36
+ def __init__(self, name, height, in_ch, mid_ch, out_ch, dilated=False):
37
+ super(RSU, self).__init__()
38
+ self.name = name
39
+ self.height = height
40
+ self.dilated = dilated
41
+ self._make_layers(height, in_ch, mid_ch, out_ch, dilated)
42
+
43
+ def forward(self, x):
44
+ sizes = _size_map(x, self.height)
45
+ x = self.rebnconvin(x)
46
+
47
+ # U-Net like symmetric encoder-decoder structure
48
+ def unet(x, height=1):
49
+ if height < self.height:
50
+ x1 = getattr(self, f'rebnconv{height}')(x)
51
+ if not self.dilated and height < self.height - 1:
52
+ x2 = unet(getattr(self, 'downsample')(x1), height + 1)
53
+ else:
54
+ x2 = unet(x1, height + 1)
55
+
56
+ x = getattr(self, f'rebnconv{height}d')(torch.cat((x2, x1), 1))
57
+ return _upsample_like(x, sizes[height - 1]) if not self.dilated and height > 1 else x
58
+ else:
59
+ return getattr(self, f'rebnconv{height}')(x)
60
+
61
+ return x + unet(x)
62
+
63
+ def _make_layers(self, height, in_ch, mid_ch, out_ch, dilated=False):
64
+ self.add_module('rebnconvin', REBNCONV(in_ch, out_ch))
65
+ self.add_module('downsample', nn.MaxPool2d(2, stride=2, ceil_mode=True))
66
+
67
+ self.add_module(f'rebnconv1', REBNCONV(out_ch, mid_ch))
68
+ self.add_module(f'rebnconv1d', REBNCONV(mid_ch * 2, out_ch))
69
+
70
+ for i in range(2, height):
71
+ dilate = 1 if not dilated else 2 ** (i - 1)
72
+ self.add_module(f'rebnconv{i}', REBNCONV(mid_ch, mid_ch, dilate=dilate))
73
+ self.add_module(f'rebnconv{i}d', REBNCONV(mid_ch * 2, mid_ch, dilate=dilate))
74
+
75
+ dilate = 2 if not dilated else 2 ** (height - 1)
76
+ self.add_module(f'rebnconv{height}', REBNCONV(mid_ch, mid_ch, dilate=dilate))
77
+
78
+
79
+ class U2NET(nn.Module):
80
+ def __init__(self, cfgs, out_ch):
81
+ super(U2NET, self).__init__()
82
+ self.out_ch = out_ch
83
+ self._make_layers(cfgs)
84
+
85
+ def forward(self, x):
86
+ sizes = _size_map(x, self.height)
87
+ maps = [] # storage for maps
88
+
89
+ # side saliency map
90
+ def unet(x, height=1):
91
+ if height < 6:
92
+ x1 = getattr(self, f'stage{height}')(x)
93
+ x2 = unet(getattr(self, 'downsample')(x1), height + 1)
94
+ x = getattr(self, f'stage{height}d')(torch.cat((x2, x1), 1))
95
+ side(x, height)
96
+ return _upsample_like(x, sizes[height - 1]) if height > 1 else x
97
+ else:
98
+ x = getattr(self, f'stage{height}')(x)
99
+ side(x, height)
100
+ return _upsample_like(x, sizes[height - 1])
101
+
102
+ def side(x, h):
103
+ # side output saliency map (before sigmoid)
104
+ x = getattr(self, f'side{h}')(x)
105
+ x = _upsample_like(x, sizes[1])
106
+ maps.append(x)
107
+
108
+ def fuse():
109
+ # fuse saliency probability maps
110
+ maps.reverse()
111
+ x = torch.cat(maps, 1)
112
+ x = getattr(self, 'outconv')(x)
113
+ maps.insert(0, x)
114
+ return [torch.sigmoid(x) for x in maps]
115
+
116
+ unet(x)
117
+ maps = fuse()
118
+ return maps
119
+
120
+ def _make_layers(self, cfgs):
121
+ self.height = int((len(cfgs) + 1) / 2)
122
+ self.add_module('downsample', nn.MaxPool2d(2, stride=2, ceil_mode=True))
123
+ for k, v in cfgs.items():
124
+ # build rsu block
125
+ self.add_module(k, RSU(v[0], *v[1]))
126
+ if v[2] > 0:
127
+ # build side layer
128
+ self.add_module(f'side{v[0][-1]}', nn.Conv2d(v[2], self.out_ch, 3, padding=1))
129
+ # build fuse layer
130
+ self.add_module('outconv', nn.Conv2d(int(self.height * self.out_ch), self.out_ch, 1))
131
+
132
+
133
+ def U2NET_full():
134
+ full = {
135
+ # cfgs for building RSUs and sides
136
+ # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
137
+ 'stage1': ['En_1', (7, 3, 32, 64), -1],
138
+ 'stage2': ['En_2', (6, 64, 32, 128), -1],
139
+ 'stage3': ['En_3', (5, 128, 64, 256), -1],
140
+ 'stage4': ['En_4', (4, 256, 128, 512), -1],
141
+ 'stage5': ['En_5', (4, 512, 256, 512, True), -1],
142
+ 'stage6': ['En_6', (4, 512, 256, 512, True), 512],
143
+ 'stage5d': ['De_5', (4, 1024, 256, 512, True), 512],
144
+ 'stage4d': ['De_4', (4, 1024, 128, 256), 256],
145
+ 'stage3d': ['De_3', (5, 512, 64, 128), 128],
146
+ 'stage2d': ['De_2', (6, 256, 32, 64), 64],
147
+ 'stage1d': ['De_1', (7, 128, 16, 64), 64],
148
+ }
149
+ return U2NET(cfgs=full, out_ch=1)
150
+
151
+
152
+ def U2NET_lite():
153
+ lite = {
154
+ # cfgs for building RSUs and sides
155
+ # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
156
+ 'stage1': ['En_1', (7, 3, 16, 64), -1],
157
+ 'stage2': ['En_2', (6, 64, 16, 64), -1],
158
+ 'stage3': ['En_3', (5, 64, 16, 64), -1],
159
+ 'stage4': ['En_4', (4, 64, 16, 64), -1],
160
+ 'stage5': ['En_5', (4, 64, 16, 64, True), -1],
161
+ 'stage6': ['En_6', (4, 64, 16, 64, True), 64],
162
+ 'stage5d': ['De_5', (4, 128, 16, 64, True), 64],
163
+ 'stage4d': ['De_4', (4, 128, 16, 64), 64],
164
+ 'stage3d': ['De_3', (5, 128, 16, 64), 64],
165
+ 'stage2d': ['De_2', (6, 128, 16, 64), 64],
166
+ 'stage1d': ['De_1', (7, 128, 16, 64), 64],
167
+ }
168
+ return U2NET(cfgs=lite, out_ch=1)
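
The refactored builders produce the same topologies from the cfg dictionaries above; a short sketch, under the same import-path assumption as before:

import torch

from U2Net_.model.u2net_refactor import U2NET_full, U2NET_lite

net = U2NET_lite()  # built from the `lite` cfg; U2NET_full() uses the `full` cfg
net.eval()

with torch.no_grad():
    maps = net(torch.randn(1, 3, 224, 224))

# maps is a list of 7 sigmoid maps: maps[0] is the fused output, followed by the side
# outputs from the shallowest decoder stage (De_1) down to the deepest encoder stage
# (En_6), each upsampled to the input resolution.
print(len(maps), maps[0].shape)  # 7 torch.Size([1, 1, 224, 224])
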
U2Net_/saved_models/face_detection_cv2/haarcascade_frontalface_default.xml ADDED
The diff for this file is too large to render. See raw diff
 
cog.yaml ADDED
@@ -0,0 +1,54 @@
1
+ # Configuration for Cog ⚙️
2
+ # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
3
+
4
+ build:
5
+ # set to true if your model requires a GPU
6
+ gpu: true
7
+ cuda: "10.1"
8
+
9
+ # a list of ubuntu apt packages to install
10
+ system_packages:
11
+ #- "libopenmpi-dev"
12
+ - "libgl1-mesa-glx"
13
+ - "libglib2.0-0"
14
+ # - "cmake-"
15
+
16
+ # python version in the form '3.8' or '3.8.12'
17
+ python_version: "3.7.9"
18
+
19
+ # a list of packages in the format <package-name>==<version>
20
+ python_packages:
21
+ # cmake==3.21.2
22
+ # - "pip==21.2.2"
23
+ - "cmake==3.14.3"
24
+ - "torch==1.7.1"
25
+ - "torchvision==0.8.2"
26
+ - "numpy==1.19.2"
27
+ - "ipython==7.21.0"
28
+ - "Pillow==8.3.1"
29
+ - "svgwrite==1.4.1"
30
+ - "svgpathtools==1.4.1"
31
+ - "cssutils==2.3.0"
32
+ - "numba==0.55.1"
33
+ - "torch-tools==0.1.5"
34
+ - "visdom==0.1.8.9"
35
+ - "ftfy==6.1.1"
36
+ - "regex==2021.8.28"
37
+ - "tqdm==4.62.3"
38
+ - "scikit-image==0.18.3"
39
+ - "gdown==4.4.0"
40
+ - "wandb==0.12.0"
41
+ - "tensorflow-gpu==1.15.2"
42
+
43
+ # commands run after the environment is setup
44
+ run:
45
+ # - /root/.pyenv/versions/3.7.9/bin/python3.7 -m pip install --upgrade pip
46
+ - export PYTHONPATH="/diffvg/build/lib.linux-x86_64-3.7"
47
+ - git clone https://github.com/BachiLi/diffvg && cd diffvg && git submodule update --init --recursive && CMAKE_PREFIX_PATH=$(pyenv prefix) DIFFVG_CUDA=1 python setup.py install
48
+ - pip install git+https://github.com/openai/CLIP.git --no-deps
49
+ - gdown "https://drive.google.com/uc?id=1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ" -O "/src/U2Net_/saved_models/"
50
+ - "echo env is ready!"
51
+ - "echo another command if needed"
52
+
53
+ # predict.py defines how predictions are run on your model
54
+ predict: "predict.py:Predictor"
config.py ADDED
@@ -0,0 +1,144 @@
1
+ import argparse
2
+ import os
3
+ import random
4
+
5
+ import numpy as np
6
+ import pydiffvg
7
+ import torch
8
+ import wandb
9
+
10
+
11
+ def set_seed(seed):
12
+ random.seed(seed)
13
+ np.random.seed(seed)
14
+ os.environ['PYTHONHASHSEED'] = str(seed)
15
+ torch.manual_seed(seed)
16
+ torch.cuda.manual_seed(seed)
17
+ torch.cuda.manual_seed_all(seed)
18
+
19
+
20
+ def parse_arguments():
21
+ parser = argparse.ArgumentParser()
22
+ # =================================
23
+ # ============ general ============
24
+ # =================================
25
+ parser.add_argument("target", help="target image path")
26
+ parser.add_argument("--output_dir", type=str,
27
+ help="directory to save the output images and loss")
28
+ parser.add_argument("--path_svg", type=str, default="none",
29
+ help="if you want to load an svg file and train from it")
30
+ parser.add_argument("--use_gpu", type=int, default=0)
31
+ parser.add_argument("--seed", type=int, default=0)
32
+ parser.add_argument("--mask_object", type=int, default=0)
33
+ parser.add_argument("--fix_scale", type=int, default=0)
34
+ parser.add_argument("--display_logs", type=int, default=0)
35
+ parser.add_argument("--display", type=int, default=0)
36
+
37
+ # =================================
38
+ # ============ wandb ============
39
+ # =================================
40
+ parser.add_argument("--use_wandb", type=int, default=0)
41
+ parser.add_argument("--wandb_user", type=str, default="yael-vinker")
42
+ parser.add_argument("--wandb_name", type=str, default="test")
43
+ parser.add_argument("--wandb_project_name", type=str, default="none")
44
+
45
+ # =================================
46
+ # =========== training ============
47
+ # =================================
48
+ parser.add_argument("--num_iter", type=int, default=500,
49
+ help="number of optimization iterations")
50
+ parser.add_argument("--num_stages", type=int, default=1,
51
+ help="training stages, you can train x strokes, then freeze them and train another x strokes etc.")
52
+ parser.add_argument("--lr_scheduler", type=int, default=0)
53
+ parser.add_argument("--lr", type=float, default=1.0)
54
+ parser.add_argument("--color_lr", type=float, default=0.01)
55
+ parser.add_argument("--color_vars_threshold", type=float, default=0.0)
56
+ parser.add_argument("--batch_size", type=int, default=1,
57
+ help="for optimization it's only one image")
58
+ parser.add_argument("--save_interval", type=int, default=10)
59
+ parser.add_argument("--eval_interval", type=int, default=10)
60
+ parser.add_argument("--image_scale", type=int, default=224)
61
+
62
+ # =================================
63
+ # ======== strokes params =========
64
+ # =================================
65
+ parser.add_argument("--num_paths", type=int,
66
+ default=16, help="number of strokes")
67
+ parser.add_argument("--width", type=float,
68
+ default=1.5, help="stroke width")
69
+ parser.add_argument("--control_points_per_seg", type=int, default=4)
70
+ parser.add_argument("--num_segments", type=int, default=1,
71
+ help="number of segments for each stroke, each stroke is a bezier curve with 4 control points")
72
+ parser.add_argument("--attention_init", type=int, default=1,
73
+ help="if True, use the attention heads of Dino model to set the location of the initial strokes")
74
+ parser.add_argument("--saliency_model", type=str, default="clip")
75
+ parser.add_argument("--saliency_clip_model", type=str, default="ViT-B/32")
76
+ parser.add_argument("--xdog_intersec", type=int, default=1)
77
+ parser.add_argument("--mask_object_attention", type=int, default=0)
78
+ parser.add_argument("--softmax_temp", type=float, default=0.3)
79
+
80
+ # =================================
81
+ # ============= loss ==============
82
+ # =================================
83
+ parser.add_argument("--percep_loss", type=str, default="none",
84
+ help="the type of perceptual loss to be used (L2/LPIPS/none)")
85
+ parser.add_argument("--perceptual_weight", type=float, default=0,
86
+ help="weight the perceptual loss")
87
+ parser.add_argument("--train_with_clip", type=int, default=0)
88
+ parser.add_argument("--clip_weight", type=float, default=0)
89
+ parser.add_argument("--start_clip", type=int, default=0)
90
+ parser.add_argument("--num_aug_clip", type=int, default=4)
91
+ parser.add_argument("--include_target_in_aug", type=int, default=0)
92
+ parser.add_argument("--augment_both", type=int, default=1,
93
+ help="if you want to apply the affine augmentation to both the sketch and image")
94
+ parser.add_argument("--augemntations", type=str, default="affine",
95
+ help="can be any combination of: 'affine_noise_eraserchunks_eraser_press'")
96
+ parser.add_argument("--noise_thresh", type=float, default=0.5)
97
+ parser.add_argument("--aug_scale_min", type=float, default=0.7)
98
+ parser.add_argument("--force_sparse", type=float, default=0,
99
+ help="if True, use L1 regularization on stroke's opacity to encourage small number of strokes")
100
+ parser.add_argument("--clip_conv_loss", type=float, default=1)
101
+ parser.add_argument("--clip_conv_loss_type", type=str, default="L2")
102
+ parser.add_argument("--clip_conv_layer_weights",
103
+ type=str, default="0,0,1.0,1.0,0")
104
+ parser.add_argument("--clip_model_name", type=str, default="RN101")
105
+ parser.add_argument("--clip_fc_loss_weight", type=float, default=0.1)
106
+ parser.add_argument("--clip_text_guide", type=float, default=0)
107
+ parser.add_argument("--text_target", type=str, default="none")
108
+
109
+ args = parser.parse_args()
110
+ set_seed(args.seed)
111
+
112
+ args.clip_conv_layer_weights = [
113
+ float(item) for item in args.clip_conv_layer_weights.split(',')]
114
+
115
+ args.output_dir = os.path.join(args.output_dir, args.wandb_name)
116
+ if not os.path.exists(args.output_dir):
117
+ os.mkdir(args.output_dir)
118
+
119
+ jpg_logs_dir = f"{args.output_dir}/jpg_logs"
120
+ svg_logs_dir = f"{args.output_dir}/svg_logs"
121
+ if not os.path.exists(jpg_logs_dir):
122
+ os.mkdir(jpg_logs_dir)
123
+ if not os.path.exists(svg_logs_dir):
124
+ os.mkdir(svg_logs_dir)
125
+
126
+ if args.use_wandb:
127
+ wandb.init(project=args.wandb_project_name, entity=args.wandb_user,
128
+ config=args, name=args.wandb_name, id=wandb.util.generate_id())
129
+
130
+ if args.use_gpu:
131
+ args.device = torch.device("cuda" if (
132
+ torch.cuda.is_available() and torch.cuda.device_count() > 0) else "cpu")
133
+ else:
134
+ args.device = torch.device("cpu")
135
+ pydiffvg.set_use_gpu(torch.cuda.is_available() and args.use_gpu)
136
+ pydiffvg.set_device(args.device)
137
+ return args
138
+
139
+
140
+ if __name__ == "__main__":
141
+ # for cog predict
142
+ args = parse_arguments()
143
+ final_config = vars(args)
144
+ np.save(f"{args.output_dir}/config_init.npy", final_config)
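
config.py is consumed through parse_arguments(), which parses the CLI flags, seeds the RNGs, creates the output directories, and selects the device. A minimal CPU sketch, assuming a hypothetical target image at target_images/camel.png and an existing output_sketches/ directory (only os.mkdir is called, so the parent must already exist):

import sys

# Hypothetical invocation: `target` is the only positional argument, everything else has a default.
sys.argv = [
    "config.py",
    "target_images/camel.png",          # assumed example path
    "--output_dir", "output_sketches",  # becomes output_sketches/<wandb_name> ("test" by default)
    "--num_paths", "16",
    "--num_iter", "500",
    "--use_gpu", "0",
]

from config import parse_arguments

args = parse_arguments()  # also creates jpg_logs/ and svg_logs/ under the output directory
print(args.device, args.clip_conv_layer_weights)  # cpu [0.0, 0.0, 1.0, 1.0, 0.0]
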
display_results.py ADDED
@@ -0,0 +1,94 @@
1
+ import argparse
2
+ import os
3
+ import re
4
+
5
+ import imageio
6
+ import matplotlib.pyplot as plt
7
+ import moviepy.editor as mvp
8
+ import numpy as np
9
+ import pydiffvg
10
+ import torch
11
+ from IPython.display import Image as Image_colab
12
+ from IPython.display import display, SVG
13
+ from PIL import Image
14
+
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument("--target_file", type=str,
17
+ help="target image file, located in <target_images>")
18
+ parser.add_argument("--num_strokes", type=int)
19
+ args = parser.parse_args()
20
+
21
+
22
+ def read_svg(path_svg, multiply=False):
23
+ device = torch.device("cuda" if (
24
+ torch.cuda.is_available() and torch.cuda.device_count() > 0) else "cpu")
25
+ canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(
26
+ path_svg)
27
+ if multiply:
28
+ canvas_width *= 2
29
+ canvas_height *= 2
30
+ for path in shapes:
31
+ path.points *= 2
32
+ path.stroke_width *= 2
33
+ _render = pydiffvg.RenderFunction.apply
34
+ scene_args = pydiffvg.RenderFunction.serialize_scene(
35
+ canvas_width, canvas_height, shapes, shape_groups)
36
+ img = _render(canvas_width, # width
37
+ canvas_height, # height
38
+ 2, # num_samples_x
39
+ 2, # num_samples_y
40
+ 0, # seed
41
+ None,
42
+ *scene_args)
43
+ img = img[:, :, 3:4] * img[:, :, :3] + \
44
+ torch.ones(img.shape[0], img.shape[1], 3,
45
+ device=device) * (1 - img[:, :, 3:4])
46
+ img = img[:, :, :3]
47
+ return img
48
+
49
+
50
+ abs_path = os.path.abspath(os.getcwd())
51
+
52
+ result_path = f"{abs_path}/output_sketches/{os.path.splitext(args.target_file)[0]}"
53
+ svg_files = os.listdir(result_path)
54
+ svg_files = [f for f in svg_files if "best.svg" in f and f"{args.num_strokes}strokes" in f]
55
+ svg_output_path = f"{result_path}/{svg_files[0]}"
56
+
57
+ target_path = f"{svg_output_path[:-9]}/input.png"
58
+
59
+ sketch_res = read_svg(svg_output_path, multiply=True).cpu().numpy()
60
+ sketch_res = Image.fromarray((sketch_res * 255).astype('uint8'), 'RGB')
61
+
62
+ input_im = Image.open(target_path).resize((224,224))
63
+ display(input_im)
64
+ display(SVG(svg_output_path))
65
+
66
+ p = re.compile("_best")
67
+ best_sketch_dir = ""
68
+ for m in p.finditer(svg_files[0]):
69
+ best_sketch_dir += svg_files[0][0: m.start()]
70
+
71
+
72
+ sketches = []
73
+ cur_path = f"{result_path}/{best_sketch_dir}"
74
+ sketch_res.save(f"{cur_path}/final_sketch.png")
75
+ print(f"You can download the result sketch from {cur_path}/final_sketch.png")
76
+
77
+ if not os.path.exists(f"{cur_path}/svg_to_png"):
78
+ os.mkdir(f"{cur_path}/svg_to_png")
79
+ if os.path.exists(f"{cur_path}/config.npy"):
80
+ config = np.load(f"{cur_path}/config.npy", allow_pickle=True)[()]
81
+ inter = config["save_interval"]
82
+ loss_eval = np.array(config['loss_eval'])
83
+ inds = np.argsort(loss_eval)
84
+ intervals = list(range(0, (inds[0] + 1) * inter, inter))
85
+ for i_ in intervals:
86
+ path_svg = f"{cur_path}/svg_logs/svg_iter{i_}.svg"
87
+ sketch = read_svg(path_svg, multiply=True).cpu().numpy()
88
+ sketch = Image.fromarray((sketch * 255).astype('uint8'), 'RGB')
89
+ # print("{0}/iter_{1:04}.png".format(cur_path, int(i_)))
90
+ sketch.save("{0}/{1}/iter_{2:04}.png".format(cur_path, "svg_to_png", int(i_)))
91
+ sketches.append(sketch)
92
+ imageio.mimsave(f"{cur_path}/sketch.gif", sketches)
93
+
94
+ print(cur_path)
models/loss.py ADDED
@@ -0,0 +1,463 @@
1
+
2
+ import collections
3
+ import CLIP_.clip as clip
4
+ import torch
5
+ import torch.nn as nn
6
+ from torchvision import models, transforms
7
+
8
+
9
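+ # Collects the enabled loss terms (e.g. the CLIP image loss and the CLIP conv-feature loss) and returns a dict of weighted loss components.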
+ class Loss(nn.Module):
10
+ def __init__(self, args):
11
+ super(Loss, self).__init__()
12
+ self.args = args
13
+ self.percep_loss = args.percep_loss
14
+
15
+ self.train_with_clip = args.train_with_clip
16
+ self.clip_weight = args.clip_weight
17
+ self.start_clip = args.start_clip
18
+
19
+ self.clip_conv_loss = args.clip_conv_loss
20
+ self.clip_fc_loss_weight = args.clip_fc_loss_weight
21
+ self.clip_text_guide = args.clip_text_guide
22
+
23
+ self.losses_to_apply = self.get_losses_to_apply()
24
+
25
+ self.loss_mapper = \
26
+ {
27
+ "clip": CLIPLoss(args),
28
+ "clip_conv_loss": CLIPConvLoss(args)
29
+ }
30
+
31
+ def get_losses_to_apply(self):
32
+ losses_to_apply = []
33
+ if self.percep_loss != "none":
34
+ losses_to_apply.append(self.percep_loss)
35
+ if self.train_with_clip and self.start_clip == 0:
36
+ losses_to_apply.append("clip")
37
+ if self.clip_conv_loss:
38
+ losses_to_apply.append("clip_conv_loss")
39
+ if self.clip_text_guide:
40
+ losses_to_apply.append("clip_text")
41
+ return losses_to_apply
42
+
43
+ def update_losses_to_apply(self, epoch):
44
+ if "clip" not in self.losses_to_apply:
45
+ if self.train_with_clip:
46
+ if epoch > self.start_clip:
47
+ self.losses_to_apply.append("clip")
48
+
49
+ def forward(self, sketches, targets, color_parameters, renderer, epoch, points_optim=None, mode="train"):
50
+ loss = 0
51
+ self.update_losses_to_apply(epoch)
52
+
53
+ losses_dict = dict.fromkeys(
54
+ self.losses_to_apply, torch.tensor([0.0]).to(self.args.device))
55
+ loss_coeffs = dict.fromkeys(self.losses_to_apply, 1.0)
56
+ loss_coeffs["clip"] = self.clip_weight
57
+ loss_coeffs["clip_text"] = self.clip_text_guide
58
+
59
+ for loss_name in self.losses_to_apply:
60
+ if loss_name in ["clip_conv_loss"]:
61
+ conv_loss = self.loss_mapper[loss_name](
62
+ sketches, targets, mode)
63
+ for layer in conv_loss.keys():
64
+ losses_dict[layer] = conv_loss[layer]
65
+ elif loss_name == "l2":
66
+ losses_dict[loss_name] = self.loss_mapper[loss_name](
67
+ sketches, targets).mean()
68
+ else:
69
+ losses_dict[loss_name] = self.loss_mapper[loss_name](
70
+ sketches, targets, mode).mean()
71
+ # loss = loss + self.loss_mapper[loss_name](sketches, targets).mean() * loss_coeffs[loss_name]
72
+
73
+ for key in self.losses_to_apply:
74
+ # loss = loss + losses_dict[key] * loss_coeffs[key]
75
+ losses_dict[key] = losses_dict[key] * loss_coeffs[key]
76
+ # print(losses_dict)
77
+ return losses_dict
78
+
79
+
80
+ class CLIPLoss(torch.nn.Module):
81
+ def __init__(self, args):
82
+ super(CLIPLoss, self).__init__()
83
+
84
+ self.args = args
85
+ self.model, clip_preprocess = clip.load(
86
+ 'ViT-B/32', args.device, jit=False)
87
+ self.model.eval()
88
+ self.preprocess = transforms.Compose(
89
+ [clip_preprocess.transforms[-1]]) # clip normalisation
90
+ self.device = args.device
91
+ self.NUM_AUGS = args.num_aug_clip
92
+ augemntations = []
93
+ if "affine" in args.augemntations:
94
+ augemntations.append(transforms.RandomPerspective(
95
+ fill=0, p=1.0, distortion_scale=0.5))
96
+ augemntations.append(transforms.RandomResizedCrop(
97
+ 224, scale=(0.8, 0.8), ratio=(1.0, 1.0)))
98
+ augemntations.append(
99
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)))
100
+ self.augment_trans = transforms.Compose(augemntations)
101
+
102
+ self.calc_target = True
103
+ self.include_target_in_aug = args.include_target_in_aug
104
+ self.counter = 0
105
+ self.augment_both = args.augment_both
106
+
107
+ def forward(self, sketches, targets, mode="train"):
108
+ if self.calc_target:
109
+ targets_ = self.preprocess(targets).to(self.device)
110
+ self.targets_features = self.model.encode_image(targets_).detach()
111
+ self.calc_target = False
112
+
113
+ if mode == "eval":
114
+ # for regular clip distance, no augmentations
115
+ with torch.no_grad():
116
+ sketches = self.preprocess(sketches).to(self.device)
117
+ sketches_features = self.model.encode_image(sketches)
118
+ return 1. - torch.cosine_similarity(sketches_features, self.targets_features)
119
+
120
+ loss_clip = 0
121
+ sketch_augs = []
122
+ img_augs = []
123
+ for n in range(self.NUM_AUGS):
124
+ augmented_pair = self.augment_trans(torch.cat([sketches, targets]))
125
+ sketch_augs.append(augmented_pair[0].unsqueeze(0))
126
+
127
+ sketch_batch = torch.cat(sketch_augs)
128
+ # sketch_utils.plot_batch(img_batch, sketch_batch, self.args, self.counter, use_wandb=False, title="fc_aug{}_iter{}_{}.jpg".format(1, self.counter, mode))
129
+ # if self.counter % 100 == 0:
130
+ # sketch_utils.plot_batch(img_batch, sketch_batch, self.args, self.counter, use_wandb=False, title="aug{}_iter{}_{}.jpg".format(1, self.counter, mode))
131
+
132
+ sketch_features = self.model.encode_image(sketch_batch)
133
+
134
+ for n in range(self.NUM_AUGS):
135
+ loss_clip += (1. - torch.cosine_similarity(
136
+ sketch_features[n:n+1], self.targets_features, dim=1))
137
+ self.counter += 1
138
+ return loss_clip
139
+ # return 1. - torch.cosine_similarity(sketches_features, self.targets_features)
140
+
141
+
142
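+ # LPIPS-style perceptual distance: VGG16 features of augmented sketch/target pairs, L2-normalised and compared per layer (no learned linear weights).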
+ class LPIPS(torch.nn.Module):
143
+ def __init__(self, pretrained=True, normalize=True, pre_relu=True, device=None):
144
+ """
145
+ Args:
146
+ pre_relu(bool): if True, selects features **before** ReLU activations
147
+ """
148
+ super(LPIPS, self).__init__()
149
+ # VGG using perceptually-learned weights (LPIPS metric)
150
+ self.normalize = normalize
151
+ self.pretrained = pretrained
152
+ augemntations = []
153
+ augemntations.append(transforms.RandomPerspective(
154
+ fill=0, p=1.0, distortion_scale=0.5))
155
+ augemntations.append(transforms.RandomResizedCrop(
156
+ 224, scale=(0.8, 0.8), ratio=(1.0, 1.0)))
157
+ self.augment_trans = transforms.Compose(augemntations)
158
+ self.feature_extractor = LPIPS._FeatureExtractor(
159
+ pretrained, pre_relu).to(device)
160
+
161
+ def _l2_normalize_features(self, x, eps=1e-10):
162
+ nrm = torch.sqrt(torch.sum(x * x, dim=1, keepdim=True))
163
+ return x / (nrm + eps)
164
+
165
+ def forward(self, pred, target, mode="train"):
166
+ """Compare VGG features of two inputs."""
167
+
168
+ # Get VGG features
169
+
170
+ sketch_augs, img_augs = [pred], [target]
171
+ if mode == "train":
172
+ for n in range(4):
173
+ augmented_pair = self.augment_trans(torch.cat([pred, target]))
174
+ sketch_augs.append(augmented_pair[0].unsqueeze(0))
175
+ img_augs.append(augmented_pair[1].unsqueeze(0))
176
+
177
+ xs = torch.cat(sketch_augs, dim=0)
178
+ ys = torch.cat(img_augs, dim=0)
179
+
180
+ pred = self.feature_extractor(xs)
181
+ target = self.feature_extractor(ys)
182
+
183
+ # L2 normalize features
184
+ if self.normalize:
185
+ pred = [self._l2_normalize_features(f) for f in pred]
186
+ target = [self._l2_normalize_features(f) for f in target]
187
+
188
+ # TODO(mgharbi) Apply Richard's linear weights?
189
+
190
+ if self.normalize:
191
+ diffs = [torch.sum((p - t) ** 2, 1)
192
+ for (p, t) in zip(pred, target)]
193
+ else:
194
+ # mean instead of sum to avoid super high range
195
+ diffs = [torch.mean((p - t) ** 2, 1)
196
+ for (p, t) in zip(pred, target)]
197
+
198
+ # Spatial average
199
+ diffs = [diff.mean([1, 2]) for diff in diffs]
200
+
201
+ return sum(diffs)
202
+
203
+ class _FeatureExtractor(torch.nn.Module):
204
+ def __init__(self, pretrained, pre_relu):
205
+ super(LPIPS._FeatureExtractor, self).__init__()
206
+ vgg_pretrained = models.vgg16(pretrained=pretrained).features
207
+
208
+ self.breakpoints = [0, 4, 9, 16, 23, 30]
209
+ if pre_relu:
210
+ for i, _ in enumerate(self.breakpoints[1:]):
211
+ self.breakpoints[i + 1] -= 1
212
+
213
+ # Split at the maxpools
214
+ for i, b in enumerate(self.breakpoints[:-1]):
215
+ ops = torch.nn.Sequential()
216
+ for idx in range(b, self.breakpoints[i + 1]):
217
+ op = vgg_pretrained[idx]
218
+ ops.add_module(str(idx), op)
219
+ # print(ops)
220
+ self.add_module("group{}".format(i), ops)
221
+
222
+ # No gradients
223
+ for p in self.parameters():
224
+ p.requires_grad = False
225
+
226
+ # Torchvision's normalization: <https://github.com/pytorch/examples/blob/42e5b996718797e45c46a25c55b031e6768f8440/imagenet/main.py#L89-L101>
227
+ self.register_buffer("shift", torch.Tensor(
228
+ [0.485, 0.456, 0.406]).view(1, 3, 1, 1))
229
+ self.register_buffer("scale", torch.Tensor(
230
+ [0.229, 0.224, 0.225]).view(1, 3, 1, 1))
231
+
232
+ def forward(self, x):
233
+ feats = []
234
+ x = (x - self.shift) / self.scale
235
+ for idx in range(len(self.breakpoints) - 1):
236
+ m = getattr(self, "group{}".format(idx))
237
+ x = m(x)
238
+ feats.append(x)
239
+ return feats
240
+
241
+
242
+ class L2_(torch.nn.Module):
243
+ def __init__(self):
244
+ """
245
+ Args:
246
+ pre_relu(bool): if True, selects features **before** ReLU activations (not used by L2_)
247
+ """
248
+ super(L2_, self).__init__()
249
+ # VGG using perceptually-learned weights (LPIPS metric)
250
+ augemntations = []
251
+ augemntations.append(transforms.RandomPerspective(
252
+ fill=0, p=1.0, distortion_scale=0.5))
253
+ augemntations.append(transforms.RandomResizedCrop(
254
+ 224, scale=(0.8, 0.8), ratio=(1.0, 1.0)))
255
+ augemntations.append(
256
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)))
257
+ self.augment_trans = transforms.Compose(augemntations)
258
+ # LOG.warning("LPIPS is untested")
259
+
260
+ def forward(self, pred, target, mode="train"):
261
+ """Compare VGG features of two inputs."""
262
+
263
+ # Get VGG features
264
+
265
+ sketch_augs, img_augs = [pred], [target]
266
+ if mode == "train":
267
+ for n in range(4):
268
+ augmented_pair = self.augment_trans(torch.cat([pred, target]))
269
+ sketch_augs.append(augmented_pair[0].unsqueeze(0))
270
+ img_augs.append(augmented_pair[1].unsqueeze(0))
271
+
272
+ pred = torch.cat(sketch_augs, dim=0)
273
+ target = torch.cat(img_augs, dim=0)
274
+ diffs = [torch.square(p - t).mean() for (p, t) in zip(pred, target)]
275
+ return sum(diffs)
276
+
277
+
278
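+ # Wraps CLIP's ViT visual encoder and records the output of each of its 12 residual blocks via forward hooks.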
+ class CLIPVisualEncoder(nn.Module):
279
+ def __init__(self, clip_model):
280
+ super().__init__()
281
+ self.clip_model = clip_model
282
+ self.featuremaps = None
283
+
284
+ for i in range(12): # 12 resblocks in VIT visual transformer
285
+ self.clip_model.visual.transformer.resblocks[i].register_forward_hook(
286
+ self.make_hook(i))
287
+
288
+ def make_hook(self, name):
289
+ def hook(module, input, output):
290
+ if len(output.shape) == 3:
291
+ self.featuremaps[name] = output.permute(
292
+ 1, 0, 2) # LND -> NLD, i.e. (batch, tokens, 768)
293
+ else:
294
+ self.featuremaps[name] = output
295
+
296
+ return hook
297
+
298
+ def forward(self, x):
299
+ self.featuremaps = collections.OrderedDict()
300
+ fc_features = self.clip_model.encode_image(x).float()
301
+ featuremaps = [self.featuremaps[k] for k in range(12)]
302
+
303
+ return fc_features, featuremaps
304
+
305
+
306
+ def l2_layers(xs_conv_features, ys_conv_features, clip_model_name):
307
+ return [torch.square(x_conv - y_conv).mean() for x_conv, y_conv in
308
+ zip(xs_conv_features, ys_conv_features)]
309
+
310
+
311
+ def l1_layers(xs_conv_features, ys_conv_features, clip_model_name):
312
+ return [torch.abs(x_conv - y_conv).mean() for x_conv, y_conv in
313
+ zip(xs_conv_features, ys_conv_features)]
314
+
315
+
316
+ def cos_layers(xs_conv_features, ys_conv_features, clip_model_name):
317
+ if "RN" in clip_model_name:
318
+ # note: torch.square takes a single tensor; use the squared feature difference for ResNet conv maps
+ return [torch.square(x_conv - y_conv).mean() for x_conv, y_conv in
+ zip(xs_conv_features, ys_conv_features)]
320
+ return [(1 - torch.cosine_similarity(x_conv, y_conv, dim=1)).mean() for x_conv, y_conv in
321
+ zip(xs_conv_features, ys_conv_features)]
322
+
323
+
324
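+ # CLIP-based perceptual loss: compares intermediate visual-encoder activations (and optionally the final embedding) of augmented sketch/target pairs.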
+ class CLIPConvLoss(torch.nn.Module):
325
+ def __init__(self, args):
326
+ super(CLIPConvLoss, self).__init__()
327
+ self.clip_model_name = args.clip_model_name
328
+ assert self.clip_model_name in [
329
+ "RN50",
330
+ "RN101",
331
+ "RN50x4",
332
+ "RN50x16",
333
+ "ViT-B/32",
334
+ "ViT-B/16",
335
+ ]
336
+
337
+ self.clip_conv_loss_type = args.clip_conv_loss_type
338
+ self.clip_fc_loss_type = "Cos" # args.clip_fc_loss_type
339
+ assert self.clip_conv_loss_type in [
340
+ "L2", "Cos", "L1",
341
+ ]
342
+ assert self.clip_fc_loss_type in [
343
+ "L2", "Cos", "L1",
344
+ ]
345
+
346
+ self.distance_metrics = \
347
+ {
348
+ "L2": l2_layers,
349
+ "L1": l1_layers,
350
+ "Cos": cos_layers
351
+ }
352
+
353
+ self.model, clip_preprocess = clip.load(
354
+ self.clip_model_name, args.device, jit=False)
355
+
356
+ if self.clip_model_name.startswith("ViT"):
357
+ self.visual_encoder = CLIPVisualEncoder(self.model)
358
+
359
+ else:
360
+ self.visual_model = self.model.visual
361
+ layers = list(self.model.visual.children())
362
+ init_layers = torch.nn.Sequential(*layers)[:8]
363
+ self.layer1 = layers[8]
364
+ self.layer2 = layers[9]
365
+ self.layer3 = layers[10]
366
+ self.layer4 = layers[11]
367
+ self.att_pool2d = layers[12]
368
+
369
+ self.args = args
370
+
371
+ self.img_size = clip_preprocess.transforms[1].size
372
+ self.model.eval()
373
+ self.target_transform = transforms.Compose([
374
+ transforms.ToTensor(),
375
+ ]) # clip normalisation
376
+ self.normalize_transform = transforms.Compose([
377
+ clip_preprocess.transforms[0], # Resize
378
+ clip_preprocess.transforms[1], # CenterCrop
379
+ clip_preprocess.transforms[-1], # Normalize
380
+ ])
381
+
382
+ self.model.eval()
383
+ self.device = args.device
384
+ self.num_augs = self.args.num_aug_clip
385
+
386
+ augemntations = []
387
+ if "affine" in args.augemntations:
388
+ augemntations.append(transforms.RandomPerspective(
389
+ fill=0, p=1.0, distortion_scale=0.5))
390
+ augemntations.append(transforms.RandomResizedCrop(
391
+ 224, scale=(0.8, 0.8), ratio=(1.0, 1.0)))
392
+ augemntations.append(
393
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)))
394
+ self.augment_trans = transforms.Compose(augemntations)
395
+
396
+ self.clip_fc_layer_dims = None # self.args.clip_fc_layer_dims
397
+ self.clip_conv_layer_dims = None # self.args.clip_conv_layer_dims
398
+ self.clip_fc_loss_weight = args.clip_fc_loss_weight
399
+ self.counter = 0
400
+
401
+ def forward(self, sketch, target, mode="train"):
402
+ """
403
+ Parameters
404
+ ----------
405
+ sketch: Torch Tensor [1, C, H, W]
406
+ target: Torch Tensor [1, C, H, W]
407
+ """
408
+ # y = self.target_transform(target).to(self.args.device)
409
+ conv_loss_dict = {}
410
+ x = sketch.to(self.device)
411
+ y = target.to(self.device)
412
+ sketch_augs, img_augs = [self.normalize_transform(x)], [
413
+ self.normalize_transform(y)]
414
+ if mode == "train":
415
+ for n in range(self.num_augs):
416
+ augmented_pair = self.augment_trans(torch.cat([x, y]))
417
+ sketch_augs.append(augmented_pair[0].unsqueeze(0))
418
+ img_augs.append(augmented_pair[1].unsqueeze(0))
419
+
420
+ xs = torch.cat(sketch_augs, dim=0).to(self.device)
421
+ ys = torch.cat(img_augs, dim=0).to(self.device)
422
+
423
+ if self.clip_model_name.startswith("RN"):
424
+ xs_fc_features, xs_conv_features = self.forward_inspection_clip_resnet(
425
+ xs.contiguous())
426
+ ys_fc_features, ys_conv_features = self.forward_inspection_clip_resnet(
427
+ ys.detach())
428
+
429
+ else:
430
+ xs_fc_features, xs_conv_features = self.visual_encoder(xs)
431
+ ys_fc_features, ys_conv_features = self.visual_encoder(ys)
432
+
433
+ conv_loss = self.distance_metrics[self.clip_conv_loss_type](
434
+ xs_conv_features, ys_conv_features, self.clip_model_name)
435
+
436
+ for layer, w in enumerate(self.args.clip_conv_layer_weights):
437
+ if w:
438
+ conv_loss_dict[f"clip_conv_loss_layer{layer}"] = conv_loss[layer] * w
439
+
440
+ if self.clip_fc_loss_weight:
441
+ # fc distance is always cos
442
+ fc_loss = (1 - torch.cosine_similarity(xs_fc_features,
443
+ ys_fc_features, dim=1)).mean()
444
+ conv_loss_dict["fc"] = fc_loss * self.clip_fc_loss_weight
445
+
446
+ self.counter += 1
447
+ return conv_loss_dict
448
+
449
+ def forward_inspection_clip_resnet(self, x):
450
+ def stem(m, x):
451
+ for conv, bn in [(m.conv1, m.bn1), (m.conv2, m.bn2), (m.conv3, m.bn3)]:
452
+ x = m.relu(bn(conv(x)))
453
+ x = m.avgpool(x)
454
+ return x
455
+ x = x.type(self.visual_model.conv1.weight.dtype)
456
+ x = stem(self.visual_model, x)
457
+ x1 = self.layer1(x)
458
+ x2 = self.layer2(x1)
459
+ x3 = self.layer3(x2)
460
+ x4 = self.layer4(x3)
461
+ y = self.att_pool2d(x4)
462
+ return y, [x, x1, x2, x3, x4]
463
+
models/painter_params.py ADDED
@@ -0,0 +1,539 @@
1
+ import random
2
+ import CLIP_.clip as clip
3
+ import numpy as np
4
+ import pydiffvg
5
+ import sketch_utils as utils
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from PIL import Image
10
+ from scipy.ndimage.filters import gaussian_filter
11
+ from skimage.color import rgb2gray
12
+ from skimage.filters import threshold_otsu
13
+ from torchvision import transforms
14
+
15
+
16
+ class Painter(torch.nn.Module):
17
+ def __init__(self, args,
18
+ num_strokes=4,
19
+ num_segments=4,
20
+ imsize=224,
21
+ device=None,
22
+ target_im=None,
23
+ mask=None):
24
+ super(Painter, self).__init__()
25
+
26
+ self.args = args
27
+ self.num_paths = num_strokes
28
+ self.num_segments = num_segments
29
+ self.width = args.width
30
+ self.control_points_per_seg = args.control_points_per_seg
31
+ self.opacity_optim = args.force_sparse
32
+ self.num_stages = args.num_stages
33
+ self.add_random_noise = "noise" in args.augemntations
34
+ self.noise_thresh = args.noise_thresh
35
+ self.softmax_temp = args.softmax_temp
36
+
37
+ self.shapes = []
38
+ self.shape_groups = []
39
+ self.device = device
40
+ self.canvas_width, self.canvas_height = imsize, imsize
41
+ self.points_vars = []
42
+ self.color_vars = []
43
+ self.color_vars_threshold = args.color_vars_threshold
44
+
45
+ self.path_svg = args.path_svg
46
+ self.strokes_per_stage = self.num_paths
47
+ self.optimize_flag = []
48
+
49
+ # attention related for strokes initialisation
50
+ self.attention_init = args.attention_init
51
+ self.target_path = args.target
52
+ self.saliency_model = args.saliency_model
53
+ self.xdog_intersec = args.xdog_intersec
54
+ self.mask_object = args.mask_object_attention
55
+
56
+ self.text_target = args.text_target # for clip gradients
57
+ self.saliency_clip_model = args.saliency_clip_model
58
+ self.define_attention_input(target_im)
59
+ self.mask = mask
60
+ self.attention_map = self.set_attention_map() if self.attention_init else None
61
+
62
+ self.thresh = self.set_attention_threshold_map() if self.attention_init else None
63
+ self.strokes_counter = 0 # counts the number of calls to "get_path"
64
+ self.epoch = 0
65
+ self.final_epoch = args.num_iter - 1
66
+
67
+
68
+ def init_image(self, stage=0):
69
+ if stage > 0:
70
+ # if multi-stage training, then add new strokes on top of the existing ones
71
+ # don't optimize on previous strokes
72
+ self.optimize_flag = [False for i in range(len(self.shapes))]
73
+ for i in range(self.strokes_per_stage):
74
+ stroke_color = torch.tensor([0.0, 0.0, 0.0, 1.0])
75
+ path = self.get_path()
76
+ self.shapes.append(path)
77
+ path_group = pydiffvg.ShapeGroup(shape_ids = torch.tensor([len(self.shapes) - 1]),
78
+ fill_color = None,
79
+ stroke_color = stroke_color)
80
+ self.shape_groups.append(path_group)
81
+ self.optimize_flag.append(True)
82
+
83
+ else:
84
+ num_paths_exists = 0
85
+ if self.path_svg != "none":
86
+ self.canvas_width, self.canvas_height, self.shapes, self.shape_groups = utils.load_svg(self.path_svg)
87
+ # if you want to add more strokes to existing ones and optimize on all of them
88
+ num_paths_exists = len(self.shapes)
89
+
90
+ for i in range(num_paths_exists, self.num_paths):
91
+ stroke_color = torch.tensor([0.0, 0.0, 0.0, 1.0])
92
+ path = self.get_path()
93
+ self.shapes.append(path)
94
+ path_group = pydiffvg.ShapeGroup(shape_ids = torch.tensor([len(self.shapes) - 1]),
95
+ fill_color = None,
96
+ stroke_color = stroke_color)
97
+ self.shape_groups.append(path_group)
98
+ self.optimize_flag = [True for i in range(len(self.shapes))]
99
+
100
+ img = self.render_warp()
101
+ img = img[:, :, 3:4] * img[:, :, :3] + torch.ones(img.shape[0], img.shape[1], 3, device = self.device) * (1 - img[:, :, 3:4])
102
+ img = img[:, :, :3]
103
+ # Convert img from HWC to NCHW
104
+ img = img.unsqueeze(0)
105
+ img = img.permute(0, 3, 1, 2).to(self.device) # NHWC -> NCHW
106
+ return img
107
+ # utils.imwrite(img.cpu(), '{}/init.png'.format(args.output_dir), gamma=args.gamma, use_wandb=args.use_wandb, wandb_name="init")
108
+
109
+ def get_image(self):
110
+ img = self.render_warp()
111
+ opacity = img[:, :, 3:4]
112
+ img = opacity * img[:, :, :3] + torch.ones(img.shape[0], img.shape[1], 3, device = self.device) * (1 - opacity)
113
+ img = img[:, :, :3]
114
+ # Convert img from HWC to NCHW
115
+ img = img.unsqueeze(0)
116
+ img = img.permute(0, 3, 1, 2).to(self.device) # NHWC -> NCHW
117
+ return img
118
+
119
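+ # Initialises one Bezier stroke: the first control point comes from the attention-based initialisation (or is sampled uniformly), the rest are small random offsets around it.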
+ def get_path(self):
120
+ points = []
121
+ self.num_control_points = torch.zeros(self.num_segments, dtype = torch.int32) + (self.control_points_per_seg - 2)
122
+ p0 = self.inds_normalised[self.strokes_counter] if self.attention_init else (random.random(), random.random())
123
+ points.append(p0)
124
+
125
+ for j in range(self.num_segments):
126
+ radius = 0.05
127
+ for k in range(self.control_points_per_seg - 1):
128
+ p1 = (p0[0] + radius * (random.random() - 0.5), p0[1] + radius * (random.random() - 0.5))
129
+ points.append(p1)
130
+ p0 = p1
131
+ points = torch.tensor(points).to(self.device)
132
+ points[:, 0] *= self.canvas_width
133
+ points[:, 1] *= self.canvas_height
134
+
135
+ path = pydiffvg.Path(num_control_points = self.num_control_points,
136
+ points = points,
137
+ stroke_width = torch.tensor(self.width),
138
+ is_closed = False)
139
+ self.strokes_counter += 1
140
+ return path
141
+
142
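+ # Rasterises the current strokes with the differentiable renderer (pydiffvg), optionally clamping stroke colours to black and jittering control points with noise.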
+ def render_warp(self):
143
+ if self.opacity_optim:
144
+ for group in self.shape_groups:
145
+ group.stroke_color.data[:3].clamp_(0., 0.) # to force black stroke
146
+ group.stroke_color.data[-1].clamp_(0., 1.) # opacity
147
+ # group.stroke_color.data[-1] = (group.stroke_color.data[-1] >= self.color_vars_threshold).float()
148
+ _render = pydiffvg.RenderFunction.apply
149
+ # optionally jitter the control points with random noise (enabled when "noise" is in args.augemntations)
150
+ if self.add_random_noise:
151
+ if random.random() > self.noise_thresh:
152
+ eps = 0.01 * min(self.canvas_width, self.canvas_height)
153
+ for path in self.shapes:
154
+ path.points.data.add_(eps * torch.randn_like(path.points))
155
+ scene_args = pydiffvg.RenderFunction.serialize_scene(\
156
+ self.canvas_width, self.canvas_height, self.shapes, self.shape_groups)
157
+ img = _render(self.canvas_width, # width
158
+ self.canvas_height, # height
159
+ 2, # num_samples_x
160
+ 2, # num_samples_y
161
+ 0, # seed
162
+ None,
163
+ *scene_args)
164
+ return img
165
+
166
+ def parameters(self):
167
+ self.points_vars = []
168
+ # strokes' location optimization
169
+ for i, path in enumerate(self.shapes):
170
+ if self.optimize_flag[i]:
171
+ path.points.requires_grad = True
172
+ self.points_vars.append(path.points)
173
+ return self.points_vars
174
+
175
+ def get_points_parans(self):
176
+ return self.points_vars
177
+
178
+ def set_color_parameters(self):
179
+ # for strokes' color optimization (opacity)
180
+ self.color_vars = []
181
+ for i, group in enumerate(self.shape_groups):
182
+ if self.optimize_flag[i]:
183
+ group.stroke_color.requires_grad = True
184
+ self.color_vars.append(group.stroke_color)
185
+ return self.color_vars
186
+
187
+ def get_color_parameters(self):
188
+ return self.color_vars
189
+
190
+ def save_svg(self, output_dir, name):
191
+ pydiffvg.save_svg('{}/{}.svg'.format(output_dir, name), self.canvas_width, self.canvas_height, self.shapes, self.shape_groups)
192
+
193
+
194
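+ # Computes self-attention maps of the [CLS] token from a pretrained DINO ViT-S/8 and upsamples them to the canvas size.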
+ def dino_attn(self):
195
+ patch_size=8 # dino hyperparameter
196
+ threshold=0.6
197
+
198
+ # for dino model
199
+ mean_imagenet = torch.Tensor([0.485, 0.456, 0.406])[None,:,None,None].to(self.device)
200
+ std_imagenet = torch.Tensor([0.229, 0.224, 0.225])[None,:,None,None].to(self.device)
201
+ totens = transforms.Compose([
202
+ transforms.Resize((self.canvas_height, self.canvas_width)),
203
+ transforms.ToTensor()
204
+ ])
205
+
206
+ dino_model = torch.hub.load('facebookresearch/dino:main', 'dino_vits8').eval().to(self.device)
207
+
208
+ self.main_im = Image.open(self.target_path).convert("RGB")
209
+ main_im_tensor = totens(self.main_im).to(self.device)
210
+ img = (main_im_tensor.unsqueeze(0) - mean_imagenet) / std_imagenet
211
+ w_featmap = img.shape[-2] // patch_size
212
+ h_featmap = img.shape[-1] // patch_size
213
+
214
+ with torch.no_grad():
215
+ attn = dino_model.get_last_selfattention(img).detach().cpu()[0]
216
+
217
+ nh = attn.shape[0]
218
+ attn = attn[:,0,1:].reshape(nh,-1)
219
+ val, idx = torch.sort(attn)
220
+ val /= torch.sum(val, dim=1, keepdim=True)
221
+ cumval = torch.cumsum(val, dim=1)
222
+ th_attn = cumval > (1 - threshold)
223
+ idx2 = torch.argsort(idx)
224
+ for head in range(nh):
225
+ th_attn[head] = th_attn[head][idx2[head]]
226
+ th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
227
+ th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=patch_size, mode="nearest")[0].cpu()
228
+
229
+ attn = attn.reshape(nh, w_featmap, h_featmap).float()
230
+ attn = nn.functional.interpolate(attn.unsqueeze(0), scale_factor=patch_size, mode="nearest")[0].cpu()
231
+
232
+ return attn
233
+
234
+
235
+ def define_attention_input(self, target_im):
236
+ model, preprocess = clip.load(self.saliency_clip_model, device=self.device, jit=False)
237
+ model.eval().to(self.device)
238
+ data_transforms = transforms.Compose([
239
+ preprocess.transforms[-1],
240
+ ])
241
+ self.image_input_attn_clip = data_transforms(target_im).to(self.device)
242
+
243
+
244
+ def clip_attn(self):
245
+ model, preprocess = clip.load(self.saliency_clip_model, device=self.device, jit=False)
246
+ model.eval().to(self.device)
247
+ text_input = clip.tokenize([self.text_target]).to(self.device)
248
+
249
+ if "RN" in self.saliency_clip_model:
250
+ saliency_layer = "layer4"
251
+ attn_map = gradCAM(
252
+ model.visual,
253
+ self.image_input_attn_clip,
254
+ model.encode_text(text_input).float(),
255
+ getattr(model.visual, saliency_layer)
256
+ )
257
+ attn_map = attn_map.squeeze().detach().cpu().numpy()
258
+ attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min())
259
+
260
+ else:
261
+ # attn_map = interpret(self.image_input_attn_clip, text_input, model, device=self.device, index=0).astype(np.float32)
262
+ attn_map = interpret(self.image_input_attn_clip, text_input, model, device=self.device)
263
+
264
+ del model
265
+ return attn_map
266
+
267
+ def set_attention_map(self):
268
+ assert self.saliency_model in ["dino", "clip"]
269
+ if self.saliency_model == "dino":
270
+ return self.dino_attn()
271
+ elif self.saliency_model == "clip":
272
+ return self.clip_attn()
273
+
274
+
275
+ def softmax(self, x, tau=0.2):
276
+ e_x = np.exp(x / tau)
277
+ return e_x / e_x.sum()
278
+
279
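+ # Samples stroke starting points from the CLIP relevance map (optionally intersected with XDoG edges), using a temperature softmax as the sampling distribution.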
+ def set_inds_clip(self):
280
+ attn_map = (self.attention_map - self.attention_map.min()) / (self.attention_map.max() - self.attention_map.min())
281
+ if self.xdog_intersec:
282
+ xdog = XDoG_()
283
+ im_xdog = xdog(self.image_input_attn_clip[0].permute(1,2,0).cpu().numpy(), k=10)
284
+ intersec_map = (1 - im_xdog) * attn_map
285
+ attn_map = intersec_map
286
+
287
+ attn_map_soft = np.copy(attn_map)
288
+ attn_map_soft[attn_map > 0] = self.softmax(attn_map[attn_map > 0], tau=self.softmax_temp)
289
+
290
+ k = self.num_stages * self.num_paths
291
+ self.inds = np.random.choice(range(attn_map.flatten().shape[0]), size=k, replace=False, p=attn_map_soft.flatten())
292
+ self.inds = np.array(np.unravel_index(self.inds, attn_map.shape)).T
293
+
294
+ self.inds_normalised = np.zeros(self.inds.shape)
295
+ self.inds_normalised[:, 0] = self.inds[:, 1] / self.canvas_width
296
+ self.inds_normalised[:, 1] = self.inds[:, 0] / self.canvas_height
297
+ self.inds_normalised = self.inds_normalised.tolist()
298
+ return attn_map_soft
299
+
300
+
301
+
302
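+ # Samples stroke starting points from the top-activated pixels of each DINO attention head, then re-samples num_paths of them weighted by the summed attention.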
+ def set_inds_dino(self):
303
+ k = max(3, (self.num_stages * self.num_paths) // 6 + 1) # sample at least the top 3 points from each attention head
304
+ num_heads = self.attention_map.shape[0]
305
+ self.inds = np.zeros((k * num_heads, 2))
306
+ # "thresh" is used for visualisaiton purposes only
307
+ thresh = torch.zeros(num_heads + 1, self.attention_map.shape[1], self.attention_map.shape[2])
308
+ softmax = nn.Softmax(dim=1)
309
+ for i in range(num_heads):
310
+ # replace "self.attention_map[i]" with "self.attention_map" to get the highest values among
311
+ # all heads.
312
+ topk, indices = np.unique(self.attention_map[i].numpy(), return_index=True)
313
+ topk = topk[::-1][:k]
314
+ cur_attn_map = self.attention_map[i].numpy()
315
+ # prob function for uniform sampling
316
+ prob = cur_attn_map.flatten()
317
+ prob[prob > topk[-1]] = 1
318
+ prob[prob <= topk[-1]] = 0
319
+ prob = prob / prob.sum()
320
+ thresh[i] = torch.Tensor(prob.reshape(cur_attn_map.shape))
321
+
322
+ # choose k pixels from each head
323
+ inds = np.random.choice(range(cur_attn_map.flatten().shape[0]), size=k, replace=False, p=prob)
324
+ inds = np.unravel_index(inds, cur_attn_map.shape)
325
+ self.inds[i * k: i * k + k, 0] = inds[0]
326
+ self.inds[i * k: i * k + k, 1] = inds[1]
327
+
328
+ # for visualisation
329
+ sum_attn = self.attention_map.sum(0).numpy()
330
+ mask = np.zeros(sum_attn.shape)
331
+ mask[thresh[:-1].sum(0) > 0] = 1
332
+ sum_attn = sum_attn * mask
333
+ sum_attn = sum_attn / sum_attn.sum()
334
+ thresh[-1] = torch.Tensor(sum_attn)
335
+
336
+ # sample num_paths from the chosen pixels.
337
+ prob_sum = sum_attn[self.inds[:, 0].astype(int), self.inds[:, 1].astype(int)]  # np.int is deprecated, use the builtin int
338
+ prob_sum = prob_sum / prob_sum.sum()
339
+ new_inds = []
340
+ for i in range(self.num_stages):
341
+ new_inds.extend(np.random.choice(range(self.inds.shape[0]), size=self.num_paths, replace=False, p=prob_sum))
342
+ self.inds = self.inds[new_inds]
343
+ print("self.inds",self.inds.shape)
344
+
345
+ self.inds_normalised = np.zeros(self.inds.shape)
346
+ self.inds_normalised[:, 0] = self.inds[:, 1] / self.canvas_width
347
+ self.inds_normalised[:, 1] = self.inds[:, 0] / self.canvas_height
348
+ self.inds_normalised = self.inds_normalised.tolist()
349
+ return thresh
350
+
351
+ def set_attention_threshold_map(self):
352
+ assert self.saliency_model in ["dino", "clip"]
353
+ if self.saliency_model == "dino":
354
+ return self.set_inds_dino()
355
+ elif self.saliency_model == "clip":
356
+ return self.set_inds_clip()
357
+
358
+
359
+ def get_attn(self):
360
+ return self.attention_map
361
+
362
+ def get_thresh(self):
363
+ return self.thresh
364
+
365
+ def get_inds(self):
366
+ return self.inds
367
+
368
+ def get_mask(self):
369
+ return self.mask
370
+
371
+ def set_random_noise(self, epoch):
372
+ if epoch % self.args.save_interval == 0:
373
+ self.add_random_noise = False
374
+ else:
375
+ self.add_random_noise = "noise" in self.args.augemntations
376
+
377
+ class PainterOptimizer:
378
+ def __init__(self, args, renderer):
379
+ self.renderer = renderer
380
+ self.points_lr = args.lr
381
+ self.color_lr = args.color_lr
382
+ self.args = args
383
+ self.optim_color = args.force_sparse
384
+
385
+ def init_optimizers(self):
386
+ self.points_optim = torch.optim.Adam(self.renderer.parameters(), lr=self.points_lr)
387
+ if self.optim_color:
388
+ self.color_optim = torch.optim.Adam(self.renderer.set_color_parameters(), lr=self.color_lr)
389
+
390
+ def update_lr(self, counter):
391
+ new_lr = utils.get_epoch_lr(counter, self.args)
392
+ for param_group in self.points_optim.param_groups:
393
+ param_group["lr"] = new_lr
394
+
395
+ def zero_grad_(self):
396
+ self.points_optim.zero_grad()
397
+ if self.optim_color:
398
+ self.color_optim.zero_grad()
399
+
400
+ def step_(self):
401
+ self.points_optim.step()
402
+ if self.optim_color:
403
+ self.color_optim.step()
404
+
405
+ def get_lr(self):
406
+ return self.points_optim.param_groups[0]['lr']
407
+
408
+
409
+ class Hook:
410
+ """Attaches to a module and records its activations and gradients."""
411
+
412
+ def __init__(self, module: nn.Module):
413
+ self.data = None
414
+ self.hook = module.register_forward_hook(self.save_grad)
415
+
416
+ def save_grad(self, module, input, output):
417
+ self.data = output
418
+ output.requires_grad_(True)
419
+ output.retain_grad()
420
+
421
+ def __enter__(self):
422
+ return self
423
+
424
+ def __exit__(self, exc_type, exc_value, exc_traceback):
425
+ self.hook.remove()
426
+
427
+ @property
428
+ def activation(self) -> torch.Tensor:
429
+ return self.data
430
+
431
+ @property
432
+ def gradient(self) -> torch.Tensor:
433
+ return self.data.grad
434
+
435
+
436
+
437
+
438
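+ # Saliency map for CLIP's ViT: averages the [CLS]-token self-attention over blocks and heads, reshapes it to the 7x7 patch grid and upsamples it to 224x224.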
+ def interpret(image, texts, model, device):
439
+ images = image.repeat(1, 1, 1, 1)
440
+ res = model.encode_image(images)
441
+ model.zero_grad()
442
+ image_attn_blocks = list(dict(model.visual.transformer.resblocks.named_children()).values())
443
+ num_tokens = image_attn_blocks[0].attn_probs.shape[-1]
444
+ R = torch.eye(num_tokens, num_tokens, dtype=image_attn_blocks[0].attn_probs.dtype).to(device)
445
+ R = R.unsqueeze(0).expand(1, num_tokens, num_tokens)
446
+ cams = [] # there are 12 attention blocks
447
+ for i, blk in enumerate(image_attn_blocks):
448
+ cam = blk.attn_probs.detach() #attn_probs shape is 12, 50, 50
449
+ # each patch is 7x7 so we have 49 pixels + 1 for positional encoding
450
+ cam = cam.reshape(1, -1, cam.shape[-1], cam.shape[-1])
451
+ cam = cam.clamp(min=0)
452
+ cam = cam.clamp(min=0).mean(dim=1) # mean over the attention heads
453
+ cams.append(cam)
454
+ R = R + torch.bmm(cam, R)
455
+
456
+ cams_avg = torch.cat(cams) # 12, 50, 50
457
+ cams_avg = cams_avg[:, 0, 1:] # 12, 1, 49
458
+ image_relevance = cams_avg.mean(dim=0).unsqueeze(0)
459
+ image_relevance = image_relevance.reshape(1, 1, 7, 7)
460
+ image_relevance = torch.nn.functional.interpolate(image_relevance, size=224, mode='bicubic')
461
+ image_relevance = image_relevance.reshape(224, 224).data.cpu().numpy().astype(np.float32)
462
+ image_relevance = (image_relevance - image_relevance.min()) / (image_relevance.max() - image_relevance.min())
463
+ return image_relevance
464
+
465
+
466
+ # Reference: https://arxiv.org/abs/1610.02391
467
+ def gradCAM(
468
+ model: nn.Module,
469
+ input: torch.Tensor,
470
+ target: torch.Tensor,
471
+ layer: nn.Module
472
+ ) -> torch.Tensor:
473
+ # Zero out any gradients at the input.
474
+ if input.grad is not None:
475
+ input.grad.data.zero_()
476
+
477
+ # Disable gradient settings.
478
+ requires_grad = {}
479
+ for name, param in model.named_parameters():
480
+ requires_grad[name] = param.requires_grad
481
+ param.requires_grad_(False)
482
+
483
+ # Attach a hook to the model at the desired layer.
484
+ assert isinstance(layer, nn.Module)
485
+ with Hook(layer) as hook:
486
+ # Do a forward and backward pass.
487
+ output = model(input)
488
+ output.backward(target)
489
+
490
+ grad = hook.gradient.float()
491
+ act = hook.activation.float()
492
+
493
+ # Global average pool gradient across spatial dimension
494
+ # to obtain importance weights.
495
+ alpha = grad.mean(dim=(2, 3), keepdim=True)
496
+ # Weighted combination of activation maps over channel
497
+ # dimension.
498
+ gradcam = torch.sum(act * alpha, dim=1, keepdim=True)
499
+ # We only want neurons with positive influence so we
500
+ # clamp any negative ones.
501
+ gradcam = torch.clamp(gradcam, min=0)
502
+
503
+ # Resize gradcam to input resolution.
504
+ gradcam = F.interpolate(
505
+ gradcam,
506
+ input.shape[2:],
507
+ mode='bicubic',
508
+ align_corners=False)
509
+
510
+ # Restore gradient settings.
511
+ for name, param in model.named_parameters():
512
+ param.requires_grad_(requires_grad[name])
513
+
514
+ return gradcam
515
+
516
+
517
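+ # XDoG (eXtended Difference-of-Gaussians) edge operator; used to restrict the attention-based stroke initialisation to contour regions.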
+ class XDoG_(object):
518
+ def __init__(self):
519
+ super(XDoG_, self).__init__()
520
+ self.gamma=0.98
521
+ self.phi=200
522
+ self.eps=-0.1
523
+ self.sigma=0.8
524
+ self.binarize=True
525
+
526
+ def __call__(self, im, k=10):
527
+ if im.shape[2] == 3:
528
+ im = rgb2gray(im)
529
+ imf1 = gaussian_filter(im, self.sigma)
530
+ imf2 = gaussian_filter(im, self.sigma * k)
531
+ imdiff = imf1 - self.gamma * imf2
532
+ imdiff = (imdiff < self.eps) * 1.0 + (imdiff >= self.eps) * (1.0 + np.tanh(self.phi * imdiff))
533
+ imdiff -= imdiff.min()
534
+ imdiff /= imdiff.max()
535
+ if self.binarize:
536
+ th = threshold_otsu(imdiff)
537
+ imdiff = imdiff >= th
538
+ imdiff = imdiff.astype('float32')
539
+ return imdiff
painterly_rendering.py ADDED
@@ -0,0 +1,190 @@
1
+ import warnings
2
+
3
+ warnings.filterwarnings('ignore')
4
+ warnings.simplefilter('ignore')
5
+
6
+ import argparse
7
+ import math
8
+ import os
9
+ import sys
10
+ import time
11
+ import traceback
12
+
13
+ import numpy as np
14
+ import PIL
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import wandb
19
+ from PIL import Image
20
+ from torchvision import models, transforms
21
+ from tqdm.auto import tqdm, trange
22
+
23
+ import config
24
+ import sketch_utils as utils
25
+ from models.loss import Loss
26
+ from models.painter_params import Painter, PainterOptimizer
27
+ from IPython.display import display, SVG
28
+
29
+
30
+ def load_renderer(args, target_im=None, mask=None):
31
+ renderer = Painter(num_strokes=args.num_paths, args=args,
32
+ num_segments=args.num_segments,
33
+ imsize=args.image_scale,
34
+ device=args.device,
35
+ target_im=target_im,
36
+ mask=mask)
37
+ renderer = renderer.to(args.device)
38
+ return renderer
39
+
40
+
41
+ def get_target(args):
42
+ target = Image.open(args.target)
43
+ if target.mode == "RGBA":
44
+ # Create a white rgba background
45
+ new_image = Image.new("RGBA", target.size, "WHITE")
46
+ # Paste the image on the background.
47
+ new_image.paste(target, (0, 0), target)
48
+ target = new_image
49
+ target = target.convert("RGB")
50
+ masked_im, mask = utils.get_mask_u2net(args, target)
51
+ if args.mask_object:
52
+ target = masked_im
53
+ if args.fix_scale:
54
+ target = utils.fix_image_scale(target)
55
+
56
+ transforms_ = []
57
+ if target.size[0] != target.size[1]:
58
+ transforms_.append(transforms.Resize(
59
+ (args.image_scale, args.image_scale), interpolation=PIL.Image.BICUBIC))
60
+ else:
61
+ transforms_.append(transforms.Resize(
62
+ args.image_scale, interpolation=PIL.Image.BICUBIC))
63
+ transforms_.append(transforms.CenterCrop(args.image_scale))
64
+ transforms_.append(transforms.ToTensor())
65
+ data_transforms = transforms.Compose(transforms_)
66
+ target_ = data_transforms(target).unsqueeze(0).to(args.device)
67
+ return target_, mask
68
+
69
+
70
+ def main(args):
71
+ loss_func = Loss(args)
72
+ inputs, mask = get_target(args)
73
+ utils.log_input(args.use_wandb, 0, inputs, args.output_dir)
74
+ renderer = load_renderer(args, inputs, mask)
75
+
76
+ optimizer = PainterOptimizer(args, renderer)
77
+ counter = 0
78
+ configs_to_save = {"loss_eval": []}
79
+ best_loss, best_fc_loss = 100, 100
80
+ best_iter, best_iter_fc = 0, 0
81
+ min_delta = 1e-5
82
+ terminate = False
83
+
84
+ renderer.set_random_noise(0)
85
+ img = renderer.init_image(stage=0)
86
+ optimizer.init_optimizers()
87
+
88
+ # not using tqdm for the jupyter demo
89
+ if args.display:
90
+ epoch_range = range(args.num_iter)
91
+ else:
92
+ epoch_range = tqdm(range(args.num_iter))
93
+
94
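+ # Main optimisation loop: render the sketch, compute the CLIP losses, back-propagate into the stroke parameters, and periodically evaluate, log and early-stop.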
+ for epoch in epoch_range:
95
+ if not args.display:
96
+ epoch_range.refresh()
97
+ renderer.set_random_noise(epoch)
98
+ if args.lr_scheduler:
99
+ optimizer.update_lr(counter)
100
+
101
+ start = time.time()
102
+ optimizer.zero_grad_()
103
+ sketches = renderer.get_image().to(args.device)
104
+ losses_dict = loss_func(sketches, inputs.detach(
105
+ ), renderer.get_color_parameters(), renderer, counter, optimizer)
106
+ loss = sum(list(losses_dict.values()))
107
+ loss.backward()
108
+ optimizer.step_()
109
+ if epoch % args.save_interval == 0:
110
+ utils.plot_batch(inputs, sketches, f"{args.output_dir}/jpg_logs", counter,
111
+ use_wandb=args.use_wandb, title=f"iter{epoch}.jpg")
112
+ renderer.save_svg(
113
+ f"{args.output_dir}/svg_logs", f"svg_iter{epoch}")
114
+ if epoch % args.eval_interval == 0:
115
+ with torch.no_grad():
116
+ losses_dict_eval = loss_func(sketches, inputs, renderer.get_color_parameters(
117
+ ), renderer.get_points_parans(), counter, optimizer, mode="eval")
118
+ loss_eval = sum(list(losses_dict_eval.values()))
119
+ configs_to_save["loss_eval"].append(loss_eval.item())
120
+ for k in losses_dict_eval.keys():
121
+ if k not in configs_to_save.keys():
122
+ configs_to_save[k] = []
123
+ configs_to_save[k].append(losses_dict_eval[k].item())
124
+ if args.clip_fc_loss_weight:
125
+ if losses_dict_eval["fc"].item() < best_fc_loss:
126
+ best_fc_loss = losses_dict_eval["fc"].item(
127
+ ) / args.clip_fc_loss_weight
128
+ best_iter_fc = epoch
129
+ # print(
130
+ # f"eval iter[{epoch}/{args.num_iter}] loss[{loss.item()}] time[{time.time() - start}]")
131
+
132
+ cur_delta = loss_eval.item() - best_loss
133
+ if abs(cur_delta) > min_delta:
134
+ if cur_delta < 0:
135
+ best_loss = loss_eval.item()
136
+ best_iter = epoch
137
+ terminate = False
138
+ utils.plot_batch(
139
+ inputs, sketches, args.output_dir, counter, use_wandb=args.use_wandb, title="best_iter.jpg")
140
+ renderer.save_svg(args.output_dir, "best_iter")
141
+
142
+ if args.use_wandb:
143
+ wandb.run.summary["best_loss"] = best_loss
144
+ wandb.run.summary["best_loss_fc"] = best_fc_loss
145
+ wandb_dict = {"delta": cur_delta,
146
+ "loss_eval": loss_eval.item()}
147
+ for k in losses_dict_eval.keys():
148
+ wandb_dict[k + "_eval"] = losses_dict_eval[k].item()
149
+ wandb.log(wandb_dict, step=counter)
150
+
151
+ if abs(cur_delta) <= min_delta:
152
+ if terminate:
153
+ break
154
+ terminate = True
155
+
156
+ if counter == 0 and args.attention_init:
157
+ utils.plot_atten(renderer.get_attn(), renderer.get_thresh(), inputs, renderer.get_inds(),
158
+ args.use_wandb, "{}/{}.jpg".format(
159
+ args.output_dir, "attention_map"),
160
+ args.saliency_model, args.display_logs)
161
+
162
+ if args.use_wandb:
163
+ wandb_dict = {"loss": loss.item(), "lr": optimizer.get_lr()}
164
+ for k in losses_dict.keys():
165
+ wandb_dict[k] = losses_dict[k].item()
166
+ wandb.log(wandb_dict, step=counter)
167
+
168
+ counter += 1
169
+
170
+ renderer.save_svg(args.output_dir, "final_svg")
171
+ path_svg = os.path.join(args.output_dir, "best_iter.svg")
172
+ utils.log_sketch_summary_final(
173
+ path_svg, args.use_wandb, args.device, best_iter, best_loss, "best total")
174
+
175
+ return configs_to_save
176
+
177
+ if __name__ == "__main__":
178
+ args = config.parse_arguments()
179
+ final_config = vars(args)
180
+ try:
181
+ configs_to_save = main(args)
182
+ except BaseException as err:
183
+ print(f"Unexpected error occurred:\n {err}")
184
+ print(traceback.format_exc())
185
+ sys.exit(1)
186
+ for k in configs_to_save.keys():
187
+ final_config[k] = configs_to_save[k]
188
+ np.save(f"{args.output_dir}/config.npy", final_config)
189
+ if args.use_wandb:
190
+ wandb.finish()
predict.py ADDED
@@ -0,0 +1,317 @@
1
+ # sudo cog push r8.im/yael-vinker/clipasso
2
+
3
+ # Prediction interface for Cog ⚙️
4
+ # https://github.com/replicate/cog/blob/main/docs/python.md
5
+
6
+ import warnings
7
+
8
+ warnings.filterwarnings('ignore')
9
+ warnings.simplefilter('ignore')
10
+
11
+ from cog import BasePredictor, Input, Path
12
+ import subprocess as sp
13
+ import os
14
+ import re
15
+
16
+ import imageio
17
+ import matplotlib.pyplot as plt
18
+ import numpy as np
19
+ import pydiffvg
20
+ import torch
21
+ from PIL import Image
22
+ import multiprocessing as mp
23
+ from shutil import copyfile
24
+
25
+ import argparse
26
+ import math
27
+ import sys
28
+ import time
29
+ import traceback
30
+
31
+ import PIL
32
+ import torch.nn as nn
33
+ import torch.nn.functional as F
34
+ import wandb
35
+ from torchvision import models, transforms
36
+ from tqdm import tqdm
37
+
38
+ import config
39
+ import sketch_utils as utils
40
+ from models.loss import Loss
41
+ from models.painter_params import Painter, PainterOptimizer
42
+
43
+
44
+ class Predictor(BasePredictor):
45
+ def setup(self):
46
+ """Load the model into memory to make running multiple predictions efficient"""
47
+ self.num_iter = 2001
48
+ self.save_interval = 100
49
+ self.num_sketches = 3
50
+ self.use_gpu = True
51
+
52
+ def predict(
53
+ self,
54
+ target_image: Path = Input(description="Input image (square, without background)"),
55
+ num_strokes: int = Input(description="The number of strokes used to create the sketch, which determines the level of abstraction",default=16),
56
+ trials: int = Input(description="It is recommended to use 3 trials to receive the best sketch, but it might be slower",default=3),
57
+ mask_object: int = Input(description="It is recommended to use images without a background; if your image contains a background, set this flag to 1 to mask it out",default=0),
58
+ fix_scale: int = Input(description="If your image is not square it might be cut off; set this flag to 1 to automatically fix the scale without cropping the image",default=0),
59
+ ) -> Path:
60
+
61
+ self.num_sketches = trials
62
+ target_image_name = os.path.basename(str(target_image))
63
+
64
+ multiprocess = False
65
+ abs_path = os.path.abspath(os.getcwd())
66
+
67
+ target = str(target_image)
68
+ assert os.path.isfile(target), f"{target} does not exist!"
69
+
70
+ test_name = os.path.splitext(target_image_name)[0]
71
+ output_dir = f"{abs_path}/output_sketches/{test_name}/"
72
+ if not os.path.exists(output_dir):
73
+ os.makedirs(output_dir)
74
+
75
+ print("=" * 50)
76
+ print(f"Processing [{target_image_name}] ...")
77
+ print(f"Results will be saved to \n[{output_dir}] ...")
78
+ print("=" * 50)
79
+
80
+ if not torch.cuda.is_available():
81
+ self.use_gpu = False
82
+ print("CUDA is not configured with GPU, running with CPU instead.")
83
+ print("Note that this will be very slow, it is recommended to use colab.")
84
+ print(f"GPU: {self.use_gpu}")
85
+ seeds = list(range(0, self.num_sketches * 1000, 1000))
86
+
87
+ losses_all = {}
88
+
89
+ for seed in seeds:
90
+ wandb_name = f"{test_name}_{num_strokes}strokes_seed{seed}"
91
+ sp.run(["python", "config.py", target,
92
+ "--num_paths", str(num_strokes),
93
+ "--output_dir", output_dir,
94
+ "--wandb_name", wandb_name,
95
+ "--num_iter", str(self.num_iter),
96
+ "--save_interval", str(self.save_interval),
97
+ "--seed", str(seed),
98
+ "--use_gpu", str(int(self.use_gpu)),
99
+ "--fix_scale", str(fix_scale),
100
+ "--mask_object", str(mask_object),
101
+ "--mask_object_attention", str(
102
+ mask_object),
103
+ "--display_logs", str(int(0))])
104
+ config_init = np.load(f"{output_dir}/{wandb_name}/config_init.npy", allow_pickle=True)[()]
105
+ args = Args(config_init)
106
+ args.cog_display = True
107
+
108
+ final_config = vars(args)
109
+ try:
110
+ configs_to_save = main(args)
111
+ except BaseException as err:
112
+ print(f"Unexpected error occurred:\n {err}")
113
+ print(traceback.format_exc())
114
+ sys.exit(1)
115
+ for k in configs_to_save.keys():
116
+ final_config[k] = configs_to_save[k]
117
+ np.save(f"{args.output_dir}/config.npy", final_config)
118
+ if args.use_wandb:
119
+ wandb.finish()
120
+
121
+ config = np.load(f"{output_dir}/{wandb_name}/config.npy",
122
+ allow_pickle=True)[()]
123
+ loss_eval = np.array(config['loss_eval'])
124
+ inds = np.argsort(loss_eval)
125
+ losses_all[wandb_name] = loss_eval[inds][0]
126
+ # return Path(f"{output_dir}/{wandb_name}/best_iter.svg")
127
+
128
+
129
+ sorted_final = dict(sorted(losses_all.items(), key=lambda item: item[1]))
130
+ copyfile(f"{output_dir}/{list(sorted_final.keys())[0]}/best_iter.svg",
131
+ f"{output_dir}/{list(sorted_final.keys())[0]}_best.svg")
132
+ target_path = f"{abs_path}/target_images/{target_image_name}"
133
+ svg_files = os.listdir(output_dir)
134
+ svg_files = [f for f in svg_files if "best.svg" in f]
135
+ svg_output_path = f"{output_dir}/{svg_files[0]}"
136
+ sketch_res = read_svg(svg_output_path, multiply=True).cpu().numpy()
137
+ sketch_res = Image.fromarray((sketch_res * 255).astype('uint8'), 'RGB')
138
+ sketch_res.save(f"{abs_path}/output_sketches/sketch.png")
139
+ return Path(svg_output_path)
140
+
141
+
142
+ class Args():
143
+ def __init__(self, config):
144
+ for k in config.keys():
145
+ setattr(self, k, config[k])
146
+
147
+
148
+ def load_renderer(args, target_im=None, mask=None):
149
+ renderer = Painter(num_strokes=args.num_paths, args=args,
150
+ num_segments=args.num_segments,
151
+ imsize=args.image_scale,
152
+ device=args.device,
153
+ target_im=target_im,
154
+ mask=mask)
155
+ renderer = renderer.to(args.device)
156
+ return renderer
157
+
158
+
159
+ def get_target(args):
160
+ target = Image.open(args.target)
161
+ if target.mode == "RGBA":
162
+ # Create a white rgba background
163
+ new_image = Image.new("RGBA", target.size, "WHITE")
164
+ # Paste the image on the background.
165
+ new_image.paste(target, (0, 0), target)
166
+ target = new_image
167
+ target = target.convert("RGB")
168
+ masked_im, mask = utils.get_mask_u2net(args, target)
169
+ if args.mask_object:
170
+ target = masked_im
171
+ if args.fix_scale:
172
+ target = utils.fix_image_scale(target)
173
+
174
+ transforms_ = []
175
+ if target.size[0] != target.size[1]:
176
+ transforms_.append(transforms.Resize(
177
+ (args.image_scale, args.image_scale), interpolation=PIL.Image.BICUBIC))
178
+ else:
179
+ transforms_.append(transforms.Resize(
180
+ args.image_scale, interpolation=PIL.Image.BICUBIC))
181
+ transforms_.append(transforms.CenterCrop(args.image_scale))
182
+ transforms_.append(transforms.ToTensor())
183
+ data_transforms = transforms.Compose(transforms_)
184
+ target_ = data_transforms(target).unsqueeze(0).to(args.device)
185
+ return target_, mask
186
+
187
+
188
+ def main(args):
189
+ loss_func = Loss(args)
190
+ inputs, mask = get_target(args)
191
+ utils.log_input(args.use_wandb, 0, inputs, args.output_dir)
192
+ renderer = load_renderer(args, inputs, mask)
193
+
194
+ optimizer = PainterOptimizer(args, renderer)
195
+ counter = 0
196
+ configs_to_save = {"loss_eval": []}
197
+ best_loss, best_fc_loss = 100, 100
198
+ best_iter, best_iter_fc = 0, 0
199
+ min_delta = 1e-5
200
+ terminate = False
201
+
202
+ renderer.set_random_noise(0)
203
+ img = renderer.init_image(stage=0)
204
+ optimizer.init_optimizers()
205
+
206
+ for epoch in tqdm(range(args.num_iter)):
207
+ renderer.set_random_noise(epoch)
208
+ if args.lr_scheduler:
209
+ optimizer.update_lr(counter)
210
+
211
+ start = time.time()
212
+ optimizer.zero_grad_()
213
+ sketches = renderer.get_image().to(args.device)
214
+ losses_dict = loss_func(sketches, inputs.detach(
215
+ ), renderer.get_color_parameters(), renderer, counter, optimizer)
216
+ loss = sum(list(losses_dict.values()))
217
+ loss.backward()
218
+ optimizer.step_()
219
+ if epoch % args.save_interval == 0:
220
+ utils.plot_batch(inputs, sketches, f"{args.output_dir}/jpg_logs", counter,
221
+ use_wandb=args.use_wandb, title=f"iter{epoch}.jpg")
222
+ renderer.save_svg(
223
+ f"{args.output_dir}/svg_logs", f"svg_iter{epoch}")
224
+ # if args.cog_display:
225
+ # yield Path(f"{args.output_dir}/svg_logs/svg_iter{epoch}.svg")
226
+
227
+
228
+ if epoch % args.eval_interval == 0:
229
+ with torch.no_grad():
230
+ losses_dict_eval = loss_func(sketches, inputs, renderer.get_color_parameters(
231
+ ), renderer.get_points_parans(), counter, optimizer, mode="eval")
232
+ loss_eval = sum(list(losses_dict_eval.values()))
233
+ configs_to_save["loss_eval"].append(loss_eval.item())
234
+ for k in losses_dict_eval.keys():
235
+ if k not in configs_to_save.keys():
236
+ configs_to_save[k] = []
237
+ configs_to_save[k].append(losses_dict_eval[k].item())
238
+ if args.clip_fc_loss_weight:
239
+ if losses_dict_eval["fc"].item() < best_fc_loss:
240
+ best_fc_loss = losses_dict_eval["fc"].item(
241
+ ) / args.clip_fc_loss_weight
242
+ best_iter_fc = epoch
243
+ # print(
244
+ # f"eval iter[{epoch}/{args.num_iter}] loss[{loss.item()}] time[{time.time() - start}]")
245
+
246
+ cur_delta = loss_eval.item() - best_loss
247
+ if abs(cur_delta) > min_delta:
248
+ if cur_delta < 0:
249
+ best_loss = loss_eval.item()
250
+ best_iter = epoch
251
+ terminate = False
252
+ utils.plot_batch(
253
+ inputs, sketches, args.output_dir, counter, use_wandb=args.use_wandb, title="best_iter.jpg")
254
+ renderer.save_svg(args.output_dir, "best_iter")
255
+
256
+ if args.use_wandb:
257
+ wandb.run.summary["best_loss"] = best_loss
258
+ wandb.run.summary["best_loss_fc"] = best_fc_loss
259
+ wandb_dict = {"delta": cur_delta,
260
+ "loss_eval": loss_eval.item()}
261
+ for k in losses_dict_eval.keys():
262
+ wandb_dict[k + "_eval"] = losses_dict_eval[k].item()
263
+ wandb.log(wandb_dict, step=counter)
264
+
265
+ if abs(cur_delta) <= min_delta:
266
+ if terminate:
267
+ break
268
+ terminate = True
269
+
270
+ if counter == 0 and args.attention_init:
271
+ utils.plot_atten(renderer.get_attn(), renderer.get_thresh(), inputs, renderer.get_inds(),
272
+ args.use_wandb, "{}/{}.jpg".format(
273
+ args.output_dir, "attention_map"),
274
+ args.saliency_model, args.display_logs)
275
+
276
+ if args.use_wandb:
277
+ wandb_dict = {"loss": loss.item(), "lr": optimizer.get_lr()}
278
+ for k in losses_dict.keys():
279
+ wandb_dict[k] = losses_dict[k].item()
280
+ wandb.log(wandb_dict, step=counter)
281
+
282
+ counter += 1
283
+
284
+ renderer.save_svg(args.output_dir, "final_svg")
285
+ path_svg = os.path.join(args.output_dir, "best_iter.svg")
286
+ utils.log_sketch_summary_final(
287
+ path_svg, args.use_wandb, args.device, best_iter, best_loss, "best total")
288
+
289
+ return configs_to_save
290
+
291
+
292
+ def read_svg(path_svg, multiply=False):
293
+ device = torch.device("cuda" if (
294
+ torch.cuda.is_available() and torch.cuda.device_count() > 0) else "cpu")
295
+ canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(
296
+ path_svg)
297
+ if multiply:
298
+ canvas_width *= 2
299
+ canvas_height *= 2
300
+ for path in shapes:
301
+ path.points *= 2
302
+ path.stroke_width *= 2
303
+ _render = pydiffvg.RenderFunction.apply
304
+ scene_args = pydiffvg.RenderFunction.serialize_scene(
305
+ canvas_width, canvas_height, shapes, shape_groups)
306
+ img = _render(canvas_width, # width
307
+ canvas_height, # height
308
+ 2, # num_samples_x
309
+ 2, # num_samples_y
310
+ 0, # seed
311
+ None,
312
+ *scene_args)
313
+ img = img[:, :, 3:4] * img[:, :, :3] + \
314
+ torch.ones(img.shape[0], img.shape[1], 3,
315
+ device=device) * (1 - img[:, :, 3:4])
316
+ img = img[:, :, :3]
317
+ return img
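For reference, a minimal sketch of how the read_svg helper above could be used to rasterize a saved sketch. The path below is an assumption (results are written under output_sketches/ by run_object_sketching.py), so adjust it to your own run:

import matplotlib.pyplot as plt

# illustrative path; point this at any SVG produced by the optimization
img = read_svg("output_sketches/camel/best_iter.svg", multiply=False)  # HxWx3 tensor in [0, 1]
plt.imshow(img.cpu().numpy())
plt.axis("off")
plt.show()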
requirements.txt ADDED
@@ -0,0 +1,53 @@
1
+ ftfy==6.0.3
2
+ cssutils==2.3.0
3
+ gdown==4.4.0
4
+ imageio==2.9.0
5
+ imageio-ffmpeg==0.4.4
6
+ importlib-metadata==4.6.4
7
+ ipykernel==6.1.0
8
+ ipython==7.26.0
9
+ ipython-genutils==0.2.0
10
+ json5==0.9.5
11
+ jsonpatch==1.32
12
+ jsonpointer==2.1
13
+ jsonschema==3.2.0
14
+ jupyter-client==6.1.12
15
+ jupyter-core==4.7.1
16
+ jupyter-server==1.10.2
17
+ jupyterlab==3.1.6
18
+ jupyterlab-pygments==0.1.2
19
+ jupyterlab-server==2.7.0
20
+ matplotlib==3.4.2
21
+ matplotlib-inline==0.1.2
22
+ moviepy==1.0.3
23
+ notebook==6.4.3
24
+ numba==0.53.1
25
+ numpy==1.20.3
26
+ nvidia-ml-py3==7.352.0
27
+ opencv-python==4.5.3.56
28
+ pandas==1.3.2
29
+ pathtools==0.1.2
30
+ Pillow==8.2.0
31
+ pip==21.2.2
32
+ plotly==5.2.1
33
+ psutil==5.8.0
34
+ ptyprocess==0.7.0
35
+ pyaml==21.8.3
36
+ regex==2021.11.10
37
+ scikit-image==0.18.1
38
+ scikit-learn==1.0.2
39
+ scipy==1.6.2
40
+ seaborn==0.11.2
41
+ subprocess32==3.5.4
42
+ svgpathtools==1.4.1
43
+ svgwrite==1.4.1
44
+ torch==1.7.1
45
+ torch-tools==0.1.5
46
+ torchfile==0.1.0
47
+ torchvision==0.8.2
48
+ tqdm==4.62.1
49
+ visdom==0.1.8.9
50
+ wandb==0.12.0
51
+ webencodings==0.5.1
52
+ websocket-client==0.57.0
53
+ zipp==3.5.0
run_object_sketching.py ADDED
@@ -0,0 +1,158 @@
1
+ import sys
2
+ import warnings
3
+
4
+ warnings.filterwarnings('ignore')
5
+ warnings.simplefilter('ignore')
6
+
7
+ import argparse
8
+ import multiprocessing as mp
9
+ import os
10
+ import subprocess as sp
11
+ from shutil import copyfile
12
+
13
+ import numpy as np
14
+ import torch
15
+ from IPython.display import Image as Image_colab
16
+ from IPython.display import display, SVG, clear_output
17
+ from ipywidgets import IntSlider, Output, IntProgress, Button
18
+ import time
19
+
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument("--target_file", type=str,
22
+ help="target image file, located in <target_images>")
23
+ parser.add_argument("--num_strokes", type=int, default=16,
24
+ help="number of strokes used to generate the sketch, this defines the level of abstraction.")
25
+ parser.add_argument("--num_iter", type=int, default=2001,
26
+ help="number of iterations")
27
+ parser.add_argument("--fix_scale", type=int, default=0,
28
+ help="if the target image is not squared, it is recommended to fix the scale")
29
+ parser.add_argument("--mask_object", type=int, default=0,
30
+ help="if the target image contains background, it's better to mask it out")
31
+ parser.add_argument("--num_sketches", type=int, default=3,
32
+ help="it is recommended to draw 3 sketches and automatically chose the best one")
33
+ parser.add_argument("--multiprocess", type=int, default=0,
34
+ help="recommended to use multiprocess if your computer has enough memory")
35
+ parser.add_argument('-colab', action='store_true')
36
+ parser.add_argument('-cpu', action='store_true')
37
+ parser.add_argument('-display', action='store_true')
38
+ parser.add_argument('--gpunum', type=int, default=0)
39
+
40
+ args = parser.parse_args()
41
+
42
+ multiprocess = not args.colab and args.num_sketches > 1 and args.multiprocess
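+ # parallel runs are only used outside Colab, when more than one sketch is requested and
+ # --multiprocess is set; otherwise the seeds are processed sequentially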
43
+
44
+ abs_path = os.path.abspath(os.getcwd())
45
+
46
+ target = f"{abs_path}/target_images/{args.target_file}"
47
+ assert os.path.isfile(target), f"{target} does not exist!"
48
+
49
+ if not os.path.isfile(f"{abs_path}/U2Net_/saved_models/u2net.pth"):
50
+ sp.run(["gdown", "https://drive.google.com/uc?id=1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ",
51
+ "-O", "U2Net_/saved_models/"])
52
+
53
+ test_name = os.path.splitext(args.target_file)[0]
54
+ output_dir = f"{abs_path}/output_sketches/{test_name}/"
55
+ if not os.path.exists(output_dir):
56
+ os.makedirs(output_dir)
57
+
58
+ num_iter = args.num_iter
59
+ save_interval = 10
60
+ use_gpu = not args.cpu
61
+
62
+ if not torch.cuda.is_available():
63
+ use_gpu = False
64
+ print("CUDA is not configured with GPU, running with CPU instead.")
65
+ print("Note that this will be very slow, it is recommended to use colab.")
66
+
67
+ if args.colab:
68
+ print("=" * 50)
69
+ print(f"Processing [{args.target_file}] ...")
70
+ if args.colab or args.display:
71
+ img_ = Image_colab(target)
72
+ display(img_)
73
+ print(f"GPU: {use_gpu}, {torch.cuda.current_device()}")
74
+ print(f"Results will be saved to \n[{output_dir}] ...")
75
+ print("=" * 50)
76
+
77
+ seeds = list(range(0, args.num_sketches * 1000, 1000))
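+ # one painterly_rendering.py run per sketch, each with its own seed (0, 1000, 2000, ...)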
78
+
79
+ exit_codes = []
80
+ manager = mp.Manager()
81
+ losses_all = manager.dict()
82
+
83
+
84
+ def run(seed, wandb_name):
85
+ exit_code = sp.run(["python", "painterly_rendering.py", target,
86
+ "--num_paths", str(args.num_strokes),
87
+ "--output_dir", output_dir,
88
+ "--wandb_name", wandb_name,
89
+ "--num_iter", str(num_iter),
90
+ "--save_interval", str(save_interval),
91
+ "--seed", str(seed),
92
+ "--use_gpu", str(int(use_gpu)),
93
+ "--fix_scale", str(args.fix_scale),
94
+ "--mask_object", str(args.mask_object),
95
+ "--mask_object_attention", str(
96
+ args.mask_object),
97
+ "--display_logs", str(int(args.colab)),
98
+ "--display", str(int(args.display))])
99
+ if exit_code.returncode:
100
+ sys.exit(1)
101
+
102
+ config = np.load(f"{output_dir}/{wandb_name}/config.npy",
103
+ allow_pickle=True)[()]
104
+ loss_eval = np.array(config['loss_eval'])
105
+ inds = np.argsort(loss_eval)
106
+ losses_all[wandb_name] = loss_eval[inds][0]
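+ # store the lowest evaluation loss of this run; these values are used at the end of the
+ # script to pick the best sketch across seeds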
107
+
108
+
109
+ def display_(seed, wandb_name):
110
+ path_to_svg = f"{output_dir}/{wandb_name}/svg_logs/"
111
+ intervals_ = list(range(0, num_iter, save_interval))
112
+ filename = f"svg_iter0.svg"
113
+ display(IntSlider())
114
+ out = Output()
115
+ display(out)
116
+ for i in intervals_:
117
+ filename = f"svg_iter{i}.svg"
118
+ not_exist = True
119
+ while not_exist:
120
+ not_exist = not os.path.isfile(f"{path_to_svg}/{filename}")
121
+ continue
122
+ with out:
123
+ clear_output()
124
+ print("")
125
+ display(IntProgress(
126
+ value=i,
127
+ min=0,
128
+ max=num_iter,
129
+ description='Processing:',
130
+ bar_style='info', # 'success', 'info', 'warning', 'danger' or ''
131
+ style={'bar_color': 'maroon'},
132
+ orientation='horizontal'
133
+ ))
134
+ display(SVG(f"{path_to_svg}/svg_iter{i}.svg"))
135
+
136
+
137
+
138
+ if multiprocess:
139
+ ncpus = 10
140
+ P = mp.Pool(ncpus) # Generate pool of workers
141
+
142
+ for seed in seeds:
143
+ wandb_name = f"{test_name}_{args.num_strokes}strokes_seed{seed}"
144
+ if multiprocess:
145
+ P.apply_async(run, (seed, wandb_name))
146
+ else:
147
+ run(seed, wandb_name)
148
+
149
+ if args.display:
150
+ time.sleep(10)
151
+ P.apply_async(display_, (0, f"{test_name}_{args.num_strokes}strokes_seed0"))
152
+
153
+ if multiprocess:
154
+ P.close()
155
+ P.join() # wait for all worker processes to finish
156
+ sorted_final = dict(sorted(losses_all.items(), key=lambda item: item[1]))
157
+ copyfile(f"{output_dir}/{list(sorted_final.keys())[0]}/best_iter.svg",
158
+ f"{output_dir}/{list(sorted_final.keys())[0]}_best.svg")
sketch_utils.py ADDED
@@ -0,0 +1,295 @@
1
+ import os
2
+
3
+ import imageio
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import pandas as pd
7
+ import pydiffvg
8
+ import skimage
9
+ import skimage.io
10
+ import torch
11
+ import wandb
12
+ import PIL
13
+ from PIL import Image
14
+ from torchvision import transforms
15
+ from torchvision.utils import make_grid
16
+ from skimage.transform import resize
17
+
18
+ from U2Net_.model import U2NET
19
+
20
+
21
+ def imwrite(img, filename, gamma=2.2, normalize=False, use_wandb=False, wandb_name="", step=0, input_im=None):
22
+ directory = os.path.dirname(filename)
23
+ if directory != '' and not os.path.exists(directory):
24
+ os.makedirs(directory)
25
+
26
+ if not isinstance(img, np.ndarray):
27
+ img = img.data.numpy()
28
+ if normalize:
29
+ img_rng = np.max(img) - np.min(img)
30
+ if img_rng > 0:
31
+ img = (img - np.min(img)) / img_rng
32
+ img = np.clip(img, 0.0, 1.0)
33
+ if img.ndim == 2:
34
+ # repeat along the third dimension
35
+ img = np.expand_dims(img, 2)
36
+ img[:, :, :3] = np.power(img[:, :, :3], 1.0/gamma)
37
+ img = (img * 255).astype(np.uint8)
38
+
39
+ skimage.io.imsave(filename, img, check_contrast=False)
40
+ images = [wandb.Image(Image.fromarray(img), caption="output")]
41
+ if input_im is not None and step == 0:
42
+ images.append(wandb.Image(input_im, caption="input"))
43
+ if use_wandb:
44
+ wandb.log({wandb_name + "_": images}, step=step)
45
+
46
+
47
+ def plot_batch(inputs, outputs, output_dir, step, use_wandb, title):
48
+ plt.figure()
49
+ plt.subplot(2, 1, 1)
50
+ grid = make_grid(inputs.clone().detach(), normalize=True, pad_value=2)
51
+ npgrid = grid.cpu().numpy()
52
+ plt.imshow(np.transpose(npgrid, (1, 2, 0)), interpolation='nearest')
53
+ plt.axis("off")
54
+ plt.title("inputs")
55
+
56
+ plt.subplot(2, 1, 2)
57
+ grid = make_grid(outputs, normalize=False, pad_value=2)
58
+ npgrid = grid.detach().cpu().numpy()
59
+ plt.imshow(np.transpose(npgrid, (1, 2, 0)), interpolation='nearest')
60
+ plt.axis("off")
61
+ plt.title("outputs")
62
+
63
+ plt.tight_layout()
64
+ if use_wandb:
65
+ wandb.log({"output": wandb.Image(plt)}, step=step)
66
+ plt.savefig("{}/{}".format(output_dir, title))
67
+ plt.close()
68
+
69
+
70
+ def log_input(use_wandb, epoch, inputs, output_dir):
71
+ grid = make_grid(inputs.clone().detach(), normalize=True, pad_value=2)
72
+ npgrid = grid.cpu().numpy()
73
+ plt.imshow(np.transpose(npgrid, (1, 2, 0)), interpolation='nearest')
74
+ plt.axis("off")
75
+ plt.tight_layout()
76
+ if use_wandb:
77
+ wandb.log({"input": wandb.Image(plt)}, step=epoch)
78
+ plt.close()
79
+ input_ = inputs[0].cpu().clone().detach().permute(1, 2, 0).numpy()
80
+ input_ = (input_ - input_.min()) / (input_.max() - input_.min())
81
+ input_ = (input_ * 255).astype(np.uint8)
82
+ imageio.imwrite("{}/{}.png".format(output_dir, "input"), input_)
83
+
84
+
85
+ def log_sketch_summary_final(path_svg, use_wandb, device, epoch, loss, title):
86
+ canvas_width, canvas_height, shapes, shape_groups = load_svg(path_svg)
87
+ _render = pydiffvg.RenderFunction.apply
88
+ scene_args = pydiffvg.RenderFunction.serialize_scene(
89
+ canvas_width, canvas_height, shapes, shape_groups)
90
+ img = _render(canvas_width, # width
91
+ canvas_height, # height
92
+ 2, # num_samples_x
93
+ 2, # num_samples_y
94
+ 0, # seed
95
+ None,
96
+ *scene_args)
97
+
98
+ img = img[:, :, 3:4] * img[:, :, :3] + \
99
+ torch.ones(img.shape[0], img.shape[1], 3,
100
+ device=device) * (1 - img[:, :, 3:4])
101
+ img = img[:, :, :3]
102
+ plt.imshow(img.cpu().numpy())
103
+ plt.axis("off")
104
+ plt.title(f"{title} best res [{epoch}] [{loss}.]")
105
+ if use_wandb:
106
+ wandb.log({title: wandb.Image(plt)})
107
+ plt.close()
108
+
109
+
110
+ def log_sketch_summary(sketch, title, use_wandb):
111
+ plt.figure()
112
+ grid = make_grid(sketch.clone().detach(), normalize=True, pad_value=2)
113
+ npgrid = grid.cpu().numpy()
114
+ plt.imshow(np.transpose(npgrid, (1, 2, 0)), interpolation='nearest')
115
+ plt.axis("off")
116
+ plt.title(title)
117
+ plt.tight_layout()
118
+ if use_wandb:
119
+ wandb.run.summary["best_loss_im"] = wandb.Image(plt)
120
+ plt.close()
121
+
122
+
123
+ def load_svg(path_svg):
124
+ svg = os.path.join(path_svg)
125
+ canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(
126
+ svg)
127
+ return canvas_width, canvas_height, shapes, shape_groups
128
+
129
+
130
+ def read_svg(path_svg, device, multiply=False):
131
+ canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(
132
+ path_svg)
133
+ if multiply:
134
+ canvas_width *= 2
135
+ canvas_height *= 2
136
+ for path in shapes:
137
+ path.points *= 2
138
+ path.stroke_width *= 2
139
+ _render = pydiffvg.RenderFunction.apply
140
+ scene_args = pydiffvg.RenderFunction.serialize_scene(
141
+ canvas_width, canvas_height, shapes, shape_groups)
142
+ img = _render(canvas_width, # width
143
+ canvas_height, # height
144
+ 2, # num_samples_x
145
+ 2, # num_samples_y
146
+ 0, # seed
147
+ None,
148
+ *scene_args)
149
+ img = img[:, :, 3:4] * img[:, :, :3] + \
150
+ torch.ones(img.shape[0], img.shape[1], 3,
151
+ device=device) * (1 - img[:, :, 3:4])
152
+ img = img[:, :, :3]
153
+ return img
154
+
155
+
156
+ def plot_attn_dino(attn, threshold_map, inputs, inds, use_wandb, output_path):
157
+ # currently supports one image (and not a batch)
158
+ plt.figure(figsize=(10, 5))
159
+
160
+ plt.subplot(2, attn.shape[0] + 2, 1)
161
+ main_im = make_grid(inputs, normalize=True, pad_value=2)
162
+ main_im = np.transpose(main_im.cpu().numpy(), (1, 2, 0))
163
+ plt.imshow(main_im, interpolation='nearest')
164
+ plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
165
+ plt.title("input im")
166
+ plt.axis("off")
167
+
168
+ plt.subplot(2, attn.shape[0] + 2, 2)
169
+ plt.imshow(attn.sum(0).numpy(), interpolation='nearest')
170
+ plt.title("atn map sum")
171
+ plt.axis("off")
172
+
173
+ plt.subplot(2, attn.shape[0] + 2, attn.shape[0] + 3)
174
+ plt.imshow(threshold_map[-1].numpy(), interpolation='nearest')
175
+ plt.title("prob sum")
176
+ plt.axis("off")
177
+
178
+ plt.subplot(2, attn.shape[0] + 2, attn.shape[0] + 4)
179
+ plt.imshow(threshold_map[:-1].sum(0).numpy(), interpolation='nearest')
180
+ plt.title("thresh sum")
181
+ plt.axis("off")
182
+
183
+ for i in range(attn.shape[0]):
184
+ plt.subplot(2, attn.shape[0] + 2, i + 3)
185
+ plt.imshow(attn[i].numpy())
186
+ plt.axis("off")
187
+ plt.subplot(2, attn.shape[0] + 2, attn.shape[0] + 1 + i + 4)
188
+ plt.imshow(threshold_map[i].numpy())
189
+ plt.axis("off")
190
+ plt.tight_layout()
191
+ if use_wandb:
192
+ wandb.log({"attention_map": wandb.Image(plt)})
193
+ plt.savefig(output_path)
194
+ plt.close()
195
+
196
+
197
+ def plot_attn_clip(attn, threshold_map, inputs, inds, use_wandb, output_path, display_logs):
198
+ # currently supports one image (and not a batch)
199
+ plt.figure(figsize=(10, 5))
200
+
201
+ plt.subplot(1, 3, 1)
202
+ main_im = make_grid(inputs, normalize=True, pad_value=2)
203
+ main_im = np.transpose(main_im.cpu().numpy(), (1, 2, 0))
204
+ plt.imshow(main_im, interpolation='nearest')
205
+ plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
206
+ plt.title("input im")
207
+ plt.axis("off")
208
+
209
+ plt.subplot(1, 3, 2)
210
+ plt.imshow(attn, interpolation='nearest', vmin=0, vmax=1)
211
+ plt.title("atn map")
212
+ plt.axis("off")
213
+
214
+ plt.subplot(1, 3, 3)
215
+ threshold_map_ = (threshold_map - threshold_map.min()) / \
216
+ (threshold_map.max() - threshold_map.min())
217
+ plt.imshow(threshold_map_, interpolation='nearest', vmin=0, vmax=1)
218
+ plt.title("prob softmax")
219
+ plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
220
+ plt.axis("off")
221
+
222
+ plt.tight_layout()
223
+ if use_wandb:
224
+ wandb.log({"attention_map": wandb.Image(plt)})
225
+ plt.savefig(output_path)
226
+ plt.close()
227
+
228
+
229
+ def plot_atten(attn, threshold_map, inputs, inds, use_wandb, output_path, saliency_model, display_logs):
230
+ if saliency_model == "dino":
231
+ plot_attn_dino(attn, threshold_map, inputs,
232
+ inds, use_wandb, output_path)
233
+ elif saliency_model == "clip":
234
+ plot_attn_clip(attn, threshold_map, inputs, inds,
235
+ use_wandb, output_path, display_logs)
236
+
237
+
238
+ def fix_image_scale(im):
239
+ im_np = np.array(im) / 255
240
+ height, width = im_np.shape[0], im_np.shape[1]
241
+ max_len = max(height, width) + 20
242
+ new_background = np.ones((max_len, max_len, 3))
243
+ y, x = max_len // 2 - height // 2, max_len // 2 - width // 2
244
+ new_background[y: y + height, x: x + width] = im_np
245
+ new_background = (new_background / new_background.max()
246
+ * 255).astype(np.uint8)
247
+ new_im = Image.fromarray(new_background)
248
+ return new_im
249
+
250
+
251
+ def get_mask_u2net(args, pil_im):
252
+ w, h = pil_im.size[0], pil_im.size[1]
253
+ im_size = min(w, h)
254
+ data_transforms = transforms.Compose([
255
+ transforms.Resize(min(320, im_size), interpolation=PIL.Image.BICUBIC),
256
+ transforms.ToTensor(),
257
+ transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(
258
+ 0.26862954, 0.26130258, 0.27577711)),
259
+ ])
260
+
261
+ input_im_trans = data_transforms(pil_im).unsqueeze(0).to(args.device)
262
+
263
+ model_dir = os.path.join("./U2Net_/saved_models/u2net.pth")
264
+ net = U2NET(3, 1)
265
+ if torch.cuda.is_available() and args.use_gpu:
266
+ net.load_state_dict(torch.load(model_dir))
267
+ net.to(args.device)
268
+ else:
269
+ net.load_state_dict(torch.load(model_dir, map_location='cpu'))
270
+ net.eval()
271
+ with torch.no_grad():
272
+ d1, d2, d3, d4, d5, d6, d7 = net(input_im_trans.detach())
273
+ pred = d1[:, 0, :, :]
274
+ pred = (pred - pred.min()) / (pred.max() - pred.min())
275
+ predict = pred
276
+ predict[predict < 0.5] = 0
277
+ predict[predict >= 0.5] = 1
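+ # binarize the normalized saliency prediction at 0.5 to obtain a hard foreground mask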
278
+ mask = torch.cat([predict, predict, predict], axis=0).permute(1, 2, 0)
279
+ mask = mask.cpu().numpy()
280
+ mask = resize(mask, (h, w), anti_aliasing=False)
281
+ mask[mask < 0.5] = 0
282
+ mask[mask >= 0.5] = 1
283
+
284
+ # predict_np = predict.clone().cpu().data.numpy()
285
+ im = Image.fromarray((mask[:, :, 0]*255).astype(np.uint8)).convert('RGB')
286
+ im.save(f"{args.output_dir}/mask.png")
287
+
288
+ im_np = np.array(pil_im)
289
+ im_np = im_np / im_np.max()
290
+ im_np = mask * im_np
291
+ im_np[mask == 0] = 1
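+ # set background pixels to white so that only the masked object remains in the image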
292
+ im_final = (im_np / im_np.max() * 255).astype(np.uint8)
293
+ im_final = Image.fromarray(im_final)
294
+
295
+ return im_final, predict
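As a small usage sketch (assumptions: the module is importable as sketch_utils, and rose.jpeg, added in this commit, is the target), fix_image_scale can also be called on its own to pad a non-square image onto a white square canvas:

from PIL import Image
from sketch_utils import fix_image_scale

im = Image.open("target_images/rose.jpeg").convert("RGB")
square = fix_image_scale(im)  # centered on a white square canvas with a small margin
square.save("rose_square.png")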
target_images/camel.png ADDED
target_images/flamingo.png ADDED
target_images/horse.png ADDED
target_images/rose.jpeg ADDED