akhaliq (HF staff) committed on
Commit
c80917c
1 Parent(s): 81832da
This view is limited to 50 files because it contains too many changes. See the raw diff.
Files changed (50)
  1. .DS_Store +0 -0
  2. LICENSE +437 -0
  3. captioning/__init__.py +0 -0
  4. captioning/data/__init__.py +0 -0
  5. captioning/data/dataloader.py +425 -0
  6. captioning/data/pth_loader.py +334 -0
  7. captioning/data/pth_loader_FineCapEval.py +334 -0
  8. captioning/models/AoAModel.py +228 -0
  9. captioning/models/AttEnsemble.py +90 -0
  10. captioning/models/AttModel.py +969 -0
  11. captioning/models/BertCapModel.py +104 -0
  12. captioning/models/CaptionModel.py +407 -0
  13. captioning/models/FCModel.py +204 -0
  14. captioning/models/M2Transformer.py +98 -0
  15. captioning/models/ShowTellModel.py +174 -0
  16. captioning/models/TransformerModel.py +363 -0
  17. captioning/models/__init__.py +73 -0
  18. captioning/models/cachedTransformer.py +420 -0
  19. captioning/models/utils.py +25 -0
  20. captioning/modules/loss_wrapper.py +127 -0
  21. captioning/modules/losses.py +218 -0
  22. captioning/utils/__init__.py +0 -0
  23. captioning/utils/clipscore.py +396 -0
  24. captioning/utils/config.py +153 -0
  25. captioning/utils/dist_utils.py +305 -0
  26. captioning/utils/div_utils.py +38 -0
  27. captioning/utils/eval_multi.py +218 -0
  28. captioning/utils/eval_utils.py +281 -0
  29. captioning/utils/misc.py +251 -0
  30. captioning/utils/opts.py +412 -0
  31. captioning/utils/resnet.py +71 -0
  32. captioning/utils/resnet_utils.py +27 -0
  33. captioning/utils/rewards.py +392 -0
  34. captioning/utils/utils.py +138 -0
  35. clip/__init__.py +1 -0
  36. clip/clip.py +193 -0
  37. clip/model.py +437 -0
  38. clip/simple_tokenizer.py +132 -0
  39. cog.yaml +26 -0
  40. configs/phase1/FineCapEval_clipRN50_mle.yml +60 -0
  41. configs/phase1/clipRN50_mle.yml +52 -0
  42. configs/phase1/transformer.yml +41 -0
  43. configs/phase2/FineCapEval_clipRN50_cider.yml +61 -0
  44. configs/phase2/FineCapEval_clipRN50_cider_clips.yml +65 -0
  45. configs/phase2/FineCapEval_clipRN50_clips.yml +64 -0
  46. configs/phase2/FineCapEval_clipRN50_clips_grammar.yml +64 -0
  47. configs/phase2/clipRN50_cider.yml +58 -0
  48. configs/phase2/clipRN50_cider_clips.yml +61 -0
  49. configs/phase2/clipRN50_clips.yml +58 -0
  50. configs/phase2/clipRN50_clips_grammar.yml +64 -0
.DS_Store ADDED
Binary file (6.15 kB).
 
LICENSE ADDED
@@ -0,0 +1,437 @@
1
+ Attribution-NonCommercial-ShareAlike 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+
55
+ =======================================================================
56
+
57
+ Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
58
+ Public License
59
+
60
+ By exercising the Licensed Rights (defined below), You accept and agree
61
+ to be bound by the terms and conditions of this Creative Commons
62
+ Attribution-NonCommercial-ShareAlike 4.0 International Public License
63
+ ("Public License"). To the extent this Public License may be
64
+ interpreted as a contract, You are granted the Licensed Rights in
65
+ consideration of Your acceptance of these terms and conditions, and the
66
+ Licensor grants You such rights in consideration of benefits the
67
+ Licensor receives from making the Licensed Material available under
68
+ these terms and conditions.
69
+
70
+
71
+ Section 1 -- Definitions.
72
+
73
+ a. Adapted Material means material subject to Copyright and Similar
74
+ Rights that is derived from or based upon the Licensed Material
75
+ and in which the Licensed Material is translated, altered,
76
+ arranged, transformed, or otherwise modified in a manner requiring
77
+ permission under the Copyright and Similar Rights held by the
78
+ Licensor. For purposes of this Public License, where the Licensed
79
+ Material is a musical work, performance, or sound recording,
80
+ Adapted Material is always produced where the Licensed Material is
81
+ synched in timed relation with a moving image.
82
+
83
+ b. Adapter's License means the license You apply to Your Copyright
84
+ and Similar Rights in Your contributions to Adapted Material in
85
+ accordance with the terms and conditions of this Public License.
86
+
87
+ c. BY-NC-SA Compatible License means a license listed at
88
+ creativecommons.org/compatiblelicenses, approved by Creative
89
+ Commons as essentially the equivalent of this Public License.
90
+
91
+ d. Copyright and Similar Rights means copyright and/or similar rights
92
+ closely related to copyright including, without limitation,
93
+ performance, broadcast, sound recording, and Sui Generis Database
94
+ Rights, without regard to how the rights are labeled or
95
+ categorized. For purposes of this Public License, the rights
96
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
97
+ Rights.
98
+
99
+ e. Effective Technological Measures means those measures that, in the
100
+ absence of proper authority, may not be circumvented under laws
101
+ fulfilling obligations under Article 11 of the WIPO Copyright
102
+ Treaty adopted on December 20, 1996, and/or similar international
103
+ agreements.
104
+
105
+ f. Exceptions and Limitations means fair use, fair dealing, and/or
106
+ any other exception or limitation to Copyright and Similar Rights
107
+ that applies to Your use of the Licensed Material.
108
+
109
+ g. License Elements means the license attributes listed in the name
110
+ of a Creative Commons Public License. The License Elements of this
111
+ Public License are Attribution, NonCommercial, and ShareAlike.
112
+
113
+ h. Licensed Material means the artistic or literary work, database,
114
+ or other material to which the Licensor applied this Public
115
+ License.
116
+
117
+ i. Licensed Rights means the rights granted to You subject to the
118
+ terms and conditions of this Public License, which are limited to
119
+ all Copyright and Similar Rights that apply to Your use of the
120
+ Licensed Material and that the Licensor has authority to license.
121
+
122
+ j. Licensor means the individual(s) or entity(ies) granting rights
123
+ under this Public License.
124
+
125
+ k. NonCommercial means not primarily intended for or directed towards
126
+ commercial advantage or monetary compensation. For purposes of
127
+ this Public License, the exchange of the Licensed Material for
128
+ other material subject to Copyright and Similar Rights by digital
129
+ file-sharing or similar means is NonCommercial provided there is
130
+ no payment of monetary compensation in connection with the
131
+ exchange.
132
+
133
+ l. Share means to provide material to the public by any means or
134
+ process that requires permission under the Licensed Rights, such
135
+ as reproduction, public display, public performance, distribution,
136
+ dissemination, communication, or importation, and to make material
137
+ available to the public including in ways that members of the
138
+ public may access the material from a place and at a time
139
+ individually chosen by them.
140
+
141
+ m. Sui Generis Database Rights means rights other than copyright
142
+ resulting from Directive 96/9/EC of the European Parliament and of
143
+ the Council of 11 March 1996 on the legal protection of databases,
144
+ as amended and/or succeeded, as well as other essentially
145
+ equivalent rights anywhere in the world.
146
+
147
+ n. You means the individual or entity exercising the Licensed Rights
148
+ under this Public License. Your has a corresponding meaning.
149
+
150
+
151
+ Section 2 -- Scope.
152
+
153
+ a. License grant.
154
+
155
+ 1. Subject to the terms and conditions of this Public License,
156
+ the Licensor hereby grants You a worldwide, royalty-free,
157
+ non-sublicensable, non-exclusive, irrevocable license to
158
+ exercise the Licensed Rights in the Licensed Material to:
159
+
160
+ a. reproduce and Share the Licensed Material, in whole or
161
+ in part, for NonCommercial purposes only; and
162
+
163
+ b. produce, reproduce, and Share Adapted Material for
164
+ NonCommercial purposes only.
165
+
166
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
167
+ Exceptions and Limitations apply to Your use, this Public
168
+ License does not apply, and You do not need to comply with
169
+ its terms and conditions.
170
+
171
+ 3. Term. The term of this Public License is specified in Section
172
+ 6(a).
173
+
174
+ 4. Media and formats; technical modifications allowed. The
175
+ Licensor authorizes You to exercise the Licensed Rights in
176
+ all media and formats whether now known or hereafter created,
177
+ and to make technical modifications necessary to do so. The
178
+ Licensor waives and/or agrees not to assert any right or
179
+ authority to forbid You from making technical modifications
180
+ necessary to exercise the Licensed Rights, including
181
+ technical modifications necessary to circumvent Effective
182
+ Technological Measures. For purposes of this Public License,
183
+ simply making modifications authorized by this Section 2(a)
184
+ (4) never produces Adapted Material.
185
+
186
+ 5. Downstream recipients.
187
+
188
+ a. Offer from the Licensor -- Licensed Material. Every
189
+ recipient of the Licensed Material automatically
190
+ receives an offer from the Licensor to exercise the
191
+ Licensed Rights under the terms and conditions of this
192
+ Public License.
193
+
194
+ b. Additional offer from the Licensor -- Adapted Material.
195
+ Every recipient of Adapted Material from You
196
+ automatically receives an offer from the Licensor to
197
+ exercise the Licensed Rights in the Adapted Material
198
+ under the conditions of the Adapter's License You apply.
199
+
200
+ c. No downstream restrictions. You may not offer or impose
201
+ any additional or different terms or conditions on, or
202
+ apply any Effective Technological Measures to, the
203
+ Licensed Material if doing so restricts exercise of the
204
+ Licensed Rights by any recipient of the Licensed
205
+ Material.
206
+
207
+ 6. No endorsement. Nothing in this Public License constitutes or
208
+ may be construed as permission to assert or imply that You
209
+ are, or that Your use of the Licensed Material is, connected
210
+ with, or sponsored, endorsed, or granted official status by,
211
+ the Licensor or others designated to receive attribution as
212
+ provided in Section 3(a)(1)(A)(i).
213
+
214
+ b. Other rights.
215
+
216
+ 1. Moral rights, such as the right of integrity, are not
217
+ licensed under this Public License, nor are publicity,
218
+ privacy, and/or other similar personality rights; however, to
219
+ the extent possible, the Licensor waives and/or agrees not to
220
+ assert any such rights held by the Licensor to the limited
221
+ extent necessary to allow You to exercise the Licensed
222
+ Rights, but not otherwise.
223
+
224
+ 2. Patent and trademark rights are not licensed under this
225
+ Public License.
226
+
227
+ 3. To the extent possible, the Licensor waives any right to
228
+ collect royalties from You for the exercise of the Licensed
229
+ Rights, whether directly or through a collecting society
230
+ under any voluntary or waivable statutory or compulsory
231
+ licensing scheme. In all other cases the Licensor expressly
232
+ reserves any right to collect such royalties, including when
233
+ the Licensed Material is used other than for NonCommercial
234
+ purposes.
235
+
236
+
237
+ Section 3 -- License Conditions.
238
+
239
+ Your exercise of the Licensed Rights is expressly made subject to the
240
+ following conditions.
241
+
242
+ a. Attribution.
243
+
244
+ 1. If You Share the Licensed Material (including in modified
245
+ form), You must:
246
+
247
+ a. retain the following if it is supplied by the Licensor
248
+ with the Licensed Material:
249
+
250
+ i. identification of the creator(s) of the Licensed
251
+ Material and any others designated to receive
252
+ attribution, in any reasonable manner requested by
253
+ the Licensor (including by pseudonym if
254
+ designated);
255
+
256
+ ii. a copyright notice;
257
+
258
+ iii. a notice that refers to this Public License;
259
+
260
+ iv. a notice that refers to the disclaimer of
261
+ warranties;
262
+
263
+ v. a URI or hyperlink to the Licensed Material to the
264
+ extent reasonably practicable;
265
+
266
+ b. indicate if You modified the Licensed Material and
267
+ retain an indication of any previous modifications; and
268
+
269
+ c. indicate the Licensed Material is licensed under this
270
+ Public License, and include the text of, or the URI or
271
+ hyperlink to, this Public License.
272
+
273
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
274
+ reasonable manner based on the medium, means, and context in
275
+ which You Share the Licensed Material. For example, it may be
276
+ reasonable to satisfy the conditions by providing a URI or
277
+ hyperlink to a resource that includes the required
278
+ information.
279
+ 3. If requested by the Licensor, You must remove any of the
280
+ information required by Section 3(a)(1)(A) to the extent
281
+ reasonably practicable.
282
+
283
+ b. ShareAlike.
284
+
285
+ In addition to the conditions in Section 3(a), if You Share
286
+ Adapted Material You produce, the following conditions also apply.
287
+
288
+ 1. The Adapter's License You apply must be a Creative Commons
289
+ license with the same License Elements, this version or
290
+ later, or a BY-NC-SA Compatible License.
291
+
292
+ 2. You must include the text of, or the URI or hyperlink to, the
293
+ Adapter's License You apply. You may satisfy this condition
294
+ in any reasonable manner based on the medium, means, and
295
+ context in which You Share Adapted Material.
296
+
297
+ 3. You may not offer or impose any additional or different terms
298
+ or conditions on, or apply any Effective Technological
299
+ Measures to, Adapted Material that restrict exercise of the
300
+ rights granted under the Adapter's License You apply.
301
+
302
+
303
+ Section 4 -- Sui Generis Database Rights.
304
+
305
+ Where the Licensed Rights include Sui Generis Database Rights that
306
+ apply to Your use of the Licensed Material:
307
+
308
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
309
+ to extract, reuse, reproduce, and Share all or a substantial
310
+ portion of the contents of the database for NonCommercial purposes
311
+ only;
312
+
313
+ b. if You include all or a substantial portion of the database
314
+ contents in a database in which You have Sui Generis Database
315
+ Rights, then the database in which You have Sui Generis Database
316
+ Rights (but not its individual contents) is Adapted Material,
317
+ including for purposes of Section 3(b); and
318
+
319
+ c. You must comply with the conditions in Section 3(a) if You Share
320
+ all or a substantial portion of the contents of the database.
321
+
322
+ For the avoidance of doubt, this Section 4 supplements and does not
323
+ replace Your obligations under this Public License where the Licensed
324
+ Rights include other Copyright and Similar Rights.
325
+
326
+
327
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
328
+
329
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
330
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
331
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
332
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
333
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
334
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
335
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
336
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
337
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
338
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
339
+
340
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
341
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
342
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
343
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
344
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
345
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
346
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
347
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
348
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
349
+
350
+ c. The disclaimer of warranties and limitation of liability provided
351
+ above shall be interpreted in a manner that, to the extent
352
+ possible, most closely approximates an absolute disclaimer and
353
+ waiver of all liability.
354
+
355
+
356
+ Section 6 -- Term and Termination.
357
+
358
+ a. This Public License applies for the term of the Copyright and
359
+ Similar Rights licensed here. However, if You fail to comply with
360
+ this Public License, then Your rights under this Public License
361
+ terminate automatically.
362
+
363
+ b. Where Your right to use the Licensed Material has terminated under
364
+ Section 6(a), it reinstates:
365
+
366
+ 1. automatically as of the date the violation is cured, provided
367
+ it is cured within 30 days of Your discovery of the
368
+ violation; or
369
+
370
+ 2. upon express reinstatement by the Licensor.
371
+
372
+ For the avoidance of doubt, this Section 6(b) does not affect any
373
+ right the Licensor may have to seek remedies for Your violations
374
+ of this Public License.
375
+
376
+ c. For the avoidance of doubt, the Licensor may also offer the
377
+ Licensed Material under separate terms or conditions or stop
378
+ distributing the Licensed Material at any time; however, doing so
379
+ will not terminate this Public License.
380
+
381
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
382
+ License.
383
+
384
+
385
+ Section 7 -- Other Terms and Conditions.
386
+
387
+ a. The Licensor shall not be bound by any additional or different
388
+ terms or conditions communicated by You unless expressly agreed.
389
+
390
+ b. Any arrangements, understandings, or agreements regarding the
391
+ Licensed Material not stated herein are separate from and
392
+ independent of the terms and conditions of this Public License.
393
+
394
+
395
+ Section 8 -- Interpretation.
396
+
397
+ a. For the avoidance of doubt, this Public License does not, and
398
+ shall not be interpreted to, reduce, limit, restrict, or impose
399
+ conditions on any use of the Licensed Material that could lawfully
400
+ be made without permission under this Public License.
401
+
402
+ b. To the extent possible, if any provision of this Public License is
403
+ deemed unenforceable, it shall be automatically reformed to the
404
+ minimum extent necessary to make it enforceable. If the provision
405
+ cannot be reformed, it shall be severed from this Public License
406
+ without affecting the enforceability of the remaining terms and
407
+ conditions.
408
+
409
+ c. No term or condition of this Public License will be waived and no
410
+ failure to comply consented to unless expressly agreed to by the
411
+ Licensor.
412
+
413
+ d. Nothing in this Public License constitutes or may be interpreted
414
+ as a limitation upon, or waiver of, any privileges and immunities
415
+ that apply to the Licensor or You, including from the legal
416
+ processes of any jurisdiction or authority.
417
+
418
+ =======================================================================
419
+
420
+ Creative Commons is not a party to its public
421
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
422
+ its public licenses to material it publishes and in those instances
423
+ will be considered the “Licensor.” The text of the Creative Commons
424
+ public licenses is dedicated to the public domain under the CC0 Public
425
+ Domain Dedication. Except for the limited purpose of indicating that
426
+ material is shared under a Creative Commons public license or as
427
+ otherwise permitted by the Creative Commons policies published at
428
+ creativecommons.org/policies, Creative Commons does not authorize the
429
+ use of the trademark "Creative Commons" or any other trademark or logo
430
+ of Creative Commons without its prior written consent including,
431
+ without limitation, in connection with any unauthorized modifications
432
+ to any of its public licenses or any other arrangements,
433
+ understandings, or agreements concerning use of licensed material. For
434
+ the avoidance of doubt, this paragraph does not form part of the
435
+ public licenses.
436
+
437
+ Creative Commons may be contacted at creativecommons.org.
captioning/__init__.py ADDED
File without changes
captioning/data/__init__.py ADDED
File without changes
captioning/data/dataloader.py ADDED
@@ -0,0 +1,425 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import json
6
+ import h5py
7
+ from lmdbdict import lmdbdict
8
+ from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC
9
+ import os
10
+ import numpy as np
11
+ import numpy.random as npr
12
+ import random
13
+ from functools import partial
14
+
15
+ import torch
16
+ import torch.utils.data as data
17
+
18
+ import multiprocessing
19
+ import six
20
+
21
+ class HybridLoader:
22
+ """
23
+ If db_path is a directory, then use normal file loading
24
+ If lmdb, then load from lmdb
25
+ The loading method depends on the extension.
26
+
27
+ in_memory: if in_memory is True, we save all the features in memory
28
+ For individual np(y|z)s, we don't need to do that because the system will do this for us.
29
+ Should be useful for lmdb or h5.
30
+ (Copied this idea from vilbert)
31
+ """
32
+ def __init__(self, db_path, ext, in_memory=False):
33
+ self.db_path = db_path
34
+ self.ext = ext
35
+ if self.ext == '.npy':
36
+ self.loader = lambda x: np.load(six.BytesIO(x))
37
+ else:
38
+ def load_npz(x):
39
+ x = np.load(six.BytesIO(x))
40
+ return x['feat'] if 'feat' in x else x['z'] # normally the key is 'feat', but under cocotest_bu the key was mistakenly saved as 'z'
41
+ self.loader = load_npz
42
+ if db_path.endswith('.lmdb'):
43
+ self.db_type = 'lmdb'
44
+ self.lmdb = lmdbdict(db_path, unsafe=True)
45
+ self.lmdb._key_dumps = DUMPS_FUNC['ascii']
46
+ self.lmdb._value_loads = LOADS_FUNC['identity']
47
+ elif db_path.endswith('.pth'): # Assume a key,value dictionary
48
+ self.db_type = 'pth'
49
+ self.feat_file = torch.load(db_path)
50
+ self.loader = lambda x: x
51
+ print('HybridLoader: ext is ignored')
52
+ elif db_path.endswith('h5'):
53
+ self.db_type = 'h5'
54
+ self.loader = lambda x: np.array(x).astype('float32')
55
+ else:
56
+ self.db_type = 'dir'
57
+
58
+ self.in_memory = in_memory
59
+ if self.in_memory:
60
+ self.features = {}
61
+
62
+ def get(self, key):
63
+
64
+ if self.in_memory and key in self.features:
65
+ # We save f_input because we want to save the
66
+ # compressed bytes to save memory
67
+ f_input = self.features[key]
68
+ elif self.db_type == 'lmdb':
69
+ f_input = self.lmdb[key]
70
+ elif self.db_type == 'pth':
71
+ f_input = self.feat_file[key]
72
+ elif self.db_type == 'h5':
73
+ f_input = h5py.File(self.db_path, 'r')[key]
74
+ else:
75
+ f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read()
76
+
77
+ if self.in_memory and key not in self.features:
78
+ self.features[key] = f_input
79
+
80
+ # load image
81
+ feat = self.loader(f_input)
82
+
83
+ return feat
84
+
85
+ class Dataset(data.Dataset):
86
+
87
+ def get_vocab_size(self):
88
+ return self.vocab_size
89
+
90
+ def get_vocab(self):
91
+ return self.ix_to_word
92
+
93
+ def get_seq_length(self):
94
+ return self.seq_length
95
+
96
+ def __init__(self, opt):
97
+ self.opt = opt
98
+ self.seq_per_img = opt.seq_per_img
99
+
100
+ # feature related options
101
+ self.use_fc = getattr(opt, 'use_fc', True)
102
+ self.use_att = getattr(opt, 'use_att', True)
103
+ self.use_box = getattr(opt, 'use_box', 0)
104
+ self.norm_att_feat = getattr(opt, 'norm_att_feat', 0)
105
+ self.norm_box_feat = getattr(opt, 'norm_box_feat', 0)
106
+
107
+ # load the json file which contains additional information about the dataset
108
+ print('DataLoader loading json file: ', opt.input_json)
109
+ self.info = json.load(open(self.opt.input_json))
110
+ if 'ix_to_word' in self.info:
111
+ self.ix_to_word = self.info['ix_to_word']
112
+ self.vocab_size = len(self.ix_to_word)
113
+ print('vocab size is ', self.vocab_size)
114
+
115
+ # open the hdf5 file
116
+ print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5)
117
+ """
118
+ Setting input_label_h5 to none is used when only doing generation.
119
+ For example, when you need to test on coco test set.
120
+ """
121
+ if self.opt.input_label_h5 != 'none':
122
+ self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core')
123
+ # load in the sequence data
124
+ seq_size = self.h5_label_file['labels'].shape
125
+ self.label = self.h5_label_file['labels'][:]
126
+ self.seq_length = seq_size[1]
127
+ print('max sequence length in data is', self.seq_length)
128
+ # load the pointers in full to RAM (should be small enough)
129
+ self.label_start_ix = self.h5_label_file['label_start_ix'][:]
130
+ self.label_end_ix = self.h5_label_file['label_end_ix'][:]
131
+ else:
132
+ self.seq_length = 1
133
+
134
+ self.data_in_memory = getattr(opt, 'data_in_memory', False)
135
+ self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory)
136
+ self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory)
137
+ self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory)
138
+
139
+ self.num_images = len(self.info['images']) # self.label_start_ix.shape[0]
140
+ print('read %d image features' %(self.num_images))
141
+
142
+ # separate out indexes for each of the provided splits
143
+ self.split_ix = {'train': [], 'val': [], 'test': []}
144
+ for ix in range(len(self.info['images'])):
145
+ img = self.info['images'][ix]
146
+ if not 'split' in img:
147
+ self.split_ix['train'].append(ix)
148
+ self.split_ix['val'].append(ix)
149
+ self.split_ix['test'].append(ix)
150
+ elif img['split'] == 'train':
151
+ self.split_ix['train'].append(ix)
152
+ elif img['split'] == 'val':
153
+ self.split_ix['val'].append(ix)
154
+ elif img['split'] == 'test':
155
+ self.split_ix['test'].append(ix)
156
+ elif opt.train_only == 0: # restval
157
+ self.split_ix['train'].append(ix)
158
+
159
+ print('assigned %d images to split train' %len(self.split_ix['train']))
160
+ print('assigned %d images to split val' %len(self.split_ix['val']))
161
+ print('assigned %d images to split test' %len(self.split_ix['test']))
162
+
163
+ def get_captions(self, ix, seq_per_img):
164
+ # fetch the sequence labels
165
+ ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1
166
+ ix2 = self.label_end_ix[ix] - 1
167
+ ncap = ix2 - ix1 + 1 # number of captions available for this image
168
+ assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'
169
+
170
+ if ncap < seq_per_img:
171
+ # we need to subsample (with replacement)
172
+ seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int')
173
+ for q in range(seq_per_img):
174
+ ixl = random.randint(ix1,ix2)
175
+ seq[q, :] = self.label[ixl, :self.seq_length]
176
+ else:
177
+ ixl = random.randint(ix1, ix2 - seq_per_img + 1)
178
+ seq = self.label[ixl: ixl + seq_per_img, :self.seq_length]
179
+
180
+ return seq
181
+
182
+ def collate_func(self, batch, split):
183
+ seq_per_img = self.seq_per_img
184
+
185
+ fc_batch = []
186
+ att_batch = []
187
+ label_batch = []
188
+
189
+ wrapped = False
190
+
191
+ infos = []
192
+ gts = []
193
+
194
+ for sample in batch:
195
+ # fetch image
196
+ tmp_fc, tmp_att, tmp_seq, \
197
+ ix, it_pos_now, tmp_wrapped = sample
198
+ if tmp_wrapped:
199
+ wrapped = True
200
+
201
+ fc_batch.append(tmp_fc)
202
+ att_batch.append(tmp_att)
203
+
204
+ tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int')
205
+ if hasattr(self, 'h5_label_file'):
206
+ # if there is ground truth
207
+ tmp_label[:, 1 : self.seq_length + 1] = tmp_seq
208
+ label_batch.append(tmp_label)
209
+
210
+ # Used for reward evaluation
211
+ if hasattr(self, 'h5_label_file'):
212
+ # if there is ground truth
213
+ gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]])
214
+ else:
215
+ gts.append([])
216
+
217
+ # record associated info as well
218
+ info_dict = {}
219
+ info_dict['ix'] = ix
220
+ info_dict['id'] = self.info['images'][ix]['id']
221
+ info_dict['file_path'] = self.info['images'][ix].get('file_path', '')
222
+ infos.append(info_dict)
223
+
224
+ # #sort by att_feat length
225
+ # fc_batch, att_batch, label_batch, gts, infos = \
226
+ # zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
227
+ fc_batch, att_batch, label_batch, gts, infos = \
228
+ zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True))
229
+ data = {}
230
+ data['fc_feats'] = np.stack(fc_batch)
231
+ # merge att_feats
232
+ max_att_len = max([_.shape[0] for _ in att_batch])
233
+ data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype = 'float32')
234
+ for i in range(len(att_batch)):
235
+ data['att_feats'][i, :att_batch[i].shape[0]] = att_batch[i]
236
+ data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32')
237
+ for i in range(len(att_batch)):
238
+ data['att_masks'][i, :att_batch[i].shape[0]] = 1
239
+ # set att_masks to None if attention features have same length
240
+ if data['att_masks'].sum() == data['att_masks'].size:
241
+ data['att_masks'] = None
242
+
243
+ data['labels'] = np.vstack(label_batch)
244
+ # generate mask
245
+ nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels'])))
246
+ mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32')
247
+ for ix, row in enumerate(mask_batch):
248
+ row[:nonzeros[ix]] = 1
249
+ data['masks'] = mask_batch
250
+ data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1)
251
+ data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1)
252
+
253
+ data['gts'] = gts # all ground truth captions for each image
254
+ data['bounds'] = {'it_pos_now': it_pos_now, # the it_pos_now of the last sample
255
+ 'it_max': len(self.split_ix[split]), 'wrapped': wrapped}
256
+ data['infos'] = infos
257
+
258
+ data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor
259
+
260
+ return data
261
+
262
+ def __getitem__(self, index):
263
+ """This function returns a tuple that is further passed to collate_fn
264
+ """
265
+ ix, it_pos_now, wrapped = index #self.split_ix[index]
266
+ if self.use_att:
267
+ att_feat = self.att_loader.get(str(self.info['images'][ix]['id']))
268
+ # Reshape to K x C
269
+ att_feat = att_feat.reshape(-1, att_feat.shape[-1])
270
+ if self.norm_att_feat:
271
+ att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True)
272
+ if self.use_box:
273
+ box_feat = self.box_loader.get(str(self.info['images'][ix]['id']))
274
+ # divided by image width and height
275
+ x1,y1,x2,y2 = np.hsplit(box_feat, 4)
276
+ h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width']
277
+ box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1??
278
+ if self.norm_box_feat:
279
+ box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True)
280
+ att_feat = np.hstack([att_feat, box_feat])
281
+ # sort the features by the size of boxes
282
+ att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True))
283
+ else:
284
+ att_feat = np.zeros((0,0), dtype='float32')
285
+ if self.use_fc:
286
+ try:
287
+ fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id']))
288
+ except:
289
+ # Use average of attention when there is no fc provided (For bottomup feature)
290
+ fc_feat = att_feat.mean(0)
291
+ else:
292
+ fc_feat = np.zeros((0), dtype='float32')
293
+ if hasattr(self, 'h5_label_file'):
294
+ seq = self.get_captions(ix, self.seq_per_img)
295
+ else:
296
+ seq = None
297
+ return (fc_feat,
298
+ att_feat, seq,
299
+ ix, it_pos_now, wrapped)
300
+
301
+ def __len__(self):
302
+ return len(self.info['images'])
303
+
304
+ class DataLoader:
305
+ def __init__(self, opt):
306
+ self.opt = opt
307
+ self.batch_size = self.opt.batch_size
308
+ self.dataset = Dataset(opt)
309
+
310
+ # Initialize loaders and iters
311
+ self.loaders, self.iters = {}, {}
312
+ for split in ['train', 'val', 'test']:
313
+ if split == 'train':
314
+ sampler = MySampler(self.dataset.split_ix[split], shuffle=True, wrap=True)
315
+ else:
316
+ sampler = MySampler(self.dataset.split_ix[split], shuffle=False, wrap=False)
317
+ self.loaders[split] = data.DataLoader(dataset=self.dataset,
318
+ batch_size=self.batch_size,
319
+ sampler=sampler,
320
+ pin_memory=True,
321
+ num_workers=4, # 4 is usually enough
322
+ collate_fn=partial(self.dataset.collate_func, split=split),
323
+ drop_last=False)
324
+ self.iters[split] = iter(self.loaders[split])
325
+
326
+ def get_batch(self, split):
327
+ try:
328
+ data = next(self.iters[split])
329
+ except StopIteration:
330
+ self.iters[split] = iter(self.loaders[split])
331
+ data = next(self.iters[split])
332
+ return data
333
+
334
+ def reset_iterator(self, split):
335
+ self.loaders[split].sampler._reset_iter()
336
+ self.iters[split] = iter(self.loaders[split])
337
+
338
+ def get_vocab_size(self):
339
+ return self.dataset.get_vocab_size()
340
+
341
+ @property
342
+ def vocab_size(self):
343
+ return self.get_vocab_size()
344
+
345
+ def get_vocab(self):
346
+ return self.dataset.get_vocab()
347
+
348
+ def get_seq_length(self):
349
+ return self.dataset.get_seq_length()
350
+
351
+ @property
352
+ def seq_length(self):
353
+ return self.get_seq_length()
354
+
355
+ def state_dict(self):
356
+ def get_prefetch_num(split):
357
+ if self.loaders[split].num_workers > 0:
358
+ return (self.iters[split]._send_idx - self.iters[split]._rcvd_idx) * self.batch_size
359
+ else:
360
+ return 0
361
+ return {split: loader.sampler.state_dict(get_prefetch_num(split)) \
362
+ for split, loader in self.loaders.items()}
363
+
364
+ def load_state_dict(self, state_dict=None):
365
+ if state_dict is None:
366
+ return
367
+ for split in self.loaders.keys():
368
+ self.loaders[split].sampler.load_state_dict(state_dict[split])
369
+
370
+
371
+ class MySampler(data.sampler.Sampler):
372
+ def __init__(self, index_list, shuffle, wrap):
373
+ self.index_list = index_list
374
+ self.shuffle = shuffle
375
+ self.wrap = wrap
376
+ # if wrap is True, StopIteration is never raised
377
+ # wrap=True is used during training; wrap=False is used during testing.
378
+ self._reset_iter()
379
+
380
+ def __iter__(self):
381
+ return self
382
+
383
+ def __next__(self):
384
+ wrapped = False
385
+ if self.iter_counter == len(self._index_list):
386
+ self._reset_iter()
387
+ if self.wrap:
388
+ wrapped = True
389
+ else:
390
+ raise StopIteration()
391
+ if len(self._index_list) == 0: # overflow when 0 samples
392
+ return None
393
+ elem = (self._index_list[self.iter_counter], self.iter_counter+1, wrapped)
394
+ self.iter_counter += 1
395
+ return elem
396
+
397
+ def next(self):
398
+ return self.__next__()
399
+
400
+ def _reset_iter(self):
401
+ if self.shuffle:
402
+ rand_perm = npr.permutation(len(self.index_list))
403
+ self._index_list = [self.index_list[_] for _ in rand_perm]
404
+ else:
405
+ self._index_list = self.index_list
406
+
407
+ self.iter_counter = 0
408
+
409
+ def __len__(self):
410
+ return len(self.index_list)
411
+
412
+ def load_state_dict(self, state_dict=None):
413
+ if state_dict is None:
414
+ return
415
+ self._index_list = state_dict['index_list']
416
+ self.iter_counter = state_dict['iter_counter']
417
+
418
+ def state_dict(self, prefetched_num=None):
419
+ prefetched_num = prefetched_num or 0
420
+ return {
421
+ 'index_list': self._index_list,
422
+ 'iter_counter': self.iter_counter - prefetched_num
423
+ }
424
+
425
+
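dataloader.py above ties together a feature-loading Dataset, a resumable MySampler, and a thin DataLoader wrapper whose get_batch() call wraps around the training split automatically. A minimal usage sketch, assuming `opt` is the options namespace produced by captioning/utils/opts.py with the input paths used above (the checkpoint filename below is hypothetical):

    import torch
    from captioning.data.dataloader import DataLoader

    loader = DataLoader(opt)                     # builds the Dataset plus per-split samplers and torch DataLoaders
    for it in range(1000):
        data = loader.get_batch('train')         # MySampler(wrap=True) restarts the split instead of raising StopIteration
        fc_feats, att_feats = data['fc_feats'], data['att_feats']
        labels, masks = data['labels'], data['masks']
        # ... forward / backward pass ...

    # the sampler positions (minus any prefetched batches) can be checkpointed and restored
    torch.save(loader.state_dict(), 'loader_state.pth')   # hypothetical path
    loader.load_state_dict(torch.load('loader_state.pth'))
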
captioning/data/pth_loader.py ADDED
@@ -0,0 +1,334 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import json
6
+ import h5py
7
+ from lmdbdict import lmdbdict
8
+ from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC
9
+ import os
10
+ import numpy as np
11
+ import numpy.random as npr
12
+ import random
13
+
14
+ import torch
15
+ import torch.utils.data as data
16
+
17
+ import multiprocessing
18
+ import six
19
+
20
+ verbose = True
21
+ # import torch
22
+ # if torch.cuda.current_device() in [0, -1]:
23
+ if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0':
24
+ verbose = False
25
+
26
+ class HybridLoader:
27
+ """
28
+ If db_path is a directory, then use normal file loading
29
+ If lmdb, then load from lmdb
30
+ The loading method depends on the extension.
31
+
32
+ in_memory: if in_memory is True, we save all the features in memory
33
+ For individual np(y|z)s, we don't need to do that because the system will do this for us.
34
+ Should be useful for lmdb or h5.
35
+ (Copied this idea from vilbert)
36
+ """
37
+ def __init__(self, db_path, ext, in_memory=False):
38
+ self.db_path = db_path
39
+ self.ext = ext
40
+ if self.ext == '.npy':
41
+ self.loader = lambda x: np.load(six.BytesIO(x))
42
+ else:
43
+ self.loader = lambda x: np.load(six.BytesIO(x))['feat']
44
+ if db_path.endswith('.lmdb'):
45
+ self.db_type = 'lmdb'
46
+ self.lmdb = lmdbdict(db_path, unsafe=True)
47
+ self.lmdb._key_dumps = DUMPS_FUNC['ascii']
48
+ self.lmdb._value_loads = LOADS_FUNC['identity']
49
+ elif db_path.endswith('.pth'): # Assume a key,value dictionary
50
+ self.db_type = 'pth'
51
+ self.feat_file = torch.load(db_path)
52
+ self.loader = lambda x: x
53
+ print('HybridLoader: ext is ignored')
54
+ elif db_path.endswith('h5'):
55
+ self.db_type = 'h5'
56
+ self.loader = lambda x: np.array(x).astype('float32')
57
+ else:
58
+ self.db_type = 'dir'
59
+
60
+ self.in_memory = in_memory
61
+ if self.in_memory:
62
+ self.features = {}
63
+
64
+ def get(self, key):
65
+
66
+ if self.in_memory and key in self.features:
67
+ # We save f_input because we want to save the
68
+ # compressed bytes to save memory
69
+ f_input = self.features[key]
70
+ elif self.db_type == 'lmdb':
71
+ f_input = self.lmdb[key]
72
+ elif self.db_type == 'pth':
73
+ f_input = self.feat_file[key]
74
+ elif self.db_type == 'h5':
75
+ f_input = h5py.File(self.db_path, 'r')[key]
76
+ else:
77
+ f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read()
78
+
79
+ if self.in_memory and key not in self.features:
80
+ self.features[key] = f_input
81
+
82
+ # load image
83
+ feat = self.loader(f_input)
84
+
85
+ return feat
86
+
87
+ class CaptionDataset(data.Dataset):
88
+
89
+ def get_vocab_size(self):
90
+ return self.vocab_size
91
+
92
+ def get_vocab(self):
93
+ return self.ix_to_word
94
+
95
+ def get_seq_length(self):
96
+ return self.seq_length
97
+
98
+ def __init__(self, opt):
99
+ self.opt = opt
100
+ self.seq_per_img = opt.seq_per_img
101
+
102
+ # feature related options
103
+ self.use_fc = getattr(opt, 'use_fc', True)
104
+ self.use_att = getattr(opt, 'use_att', True)
105
+ self.use_box = getattr(opt, 'use_box', 0)
106
+ self.norm_att_feat = getattr(opt, 'norm_att_feat', 0)
107
+ self.norm_box_feat = getattr(opt, 'norm_box_feat', 0)
108
+
109
+ # load the json file which contains additional information about the dataset
110
+ if verbose:
111
+ print('DataLoader loading json file: ', opt.input_json)
112
+ self.info = json.load(open(self.opt.input_json))
113
+ if 'ix_to_word' in self.info:
114
+ self.ix_to_word = self.info['ix_to_word']
115
+ self.vocab_size = len(self.ix_to_word)
116
+ if verbose:
117
+ print('vocab size is ', self.vocab_size)
118
+
119
+ # open the hdf5 file
120
+ if verbose:
121
+ print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5)
122
+ """
123
+ Setting input_label_h5 to none is used when only doing generation.
124
+ For example, when you need to test on coco test set.
125
+ """
126
+ if self.opt.input_label_h5 != 'none':
127
+ self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core')
128
+ # load in the sequence data
129
+ seq_size = self.h5_label_file['labels'].shape
130
+ self.label = self.h5_label_file['labels'][:]
131
+ self.seq_length = seq_size[1]
132
+ if verbose:
133
+ print('max sequence length in data is', self.seq_length)
134
+ # load the pointers in full to RAM (should be small enough)
135
+ self.label_start_ix = self.h5_label_file['label_start_ix'][:]
136
+ self.label_end_ix = self.h5_label_file['label_end_ix'][:]
137
+ else:
138
+ self.seq_length = 1
139
+
140
+ self.data_in_memory = getattr(opt, 'data_in_memory', False)
141
+ self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory)
142
+ self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory)
143
+ self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory)
144
+
145
+ self.use_clipscore = getattr(opt, 'use_clipscore', False)
146
+ # if self.use_clipscore:
147
+ self.clipscore_loader = HybridLoader(self.opt.input_clipscore_vis_dir, '.npy', in_memory=self.data_in_memory)
148
+
149
+
150
+ self.num_images = len(self.info['images']) # self.label_start_ix.shape[0]
151
+ if verbose:
152
+ print('read %d image features' %(self.num_images))
153
+
154
+ # separate out indexes for each of the provided splits
155
+ self.split_ix = {'train': [], 'val': [], 'test': []}
156
+ for ix in range(len(self.info['images'])):
157
+ img = self.info['images'][ix]
158
+ if not 'split' in img:
159
+ self.split_ix['train'].append(ix)
160
+ self.split_ix['val'].append(ix)
161
+ self.split_ix['test'].append(ix)
162
+ elif img['split'] == 'train':
163
+ self.split_ix['train'].append(ix)
164
+ elif img['split'] == 'val':
165
+ self.split_ix['val'].append(ix)
166
+ elif img['split'] == 'test':
167
+ self.split_ix['test'].append(ix)
168
+ elif opt.train_only == 0: # restval
169
+ self.split_ix['train'].append(ix)
170
+
171
+ if verbose:
172
+ print('assigned %d images to split train' %len(self.split_ix['train']))
173
+ print('assigned %d images to split val' %len(self.split_ix['val']))
174
+ print('assigned %d images to split test' %len(self.split_ix['test']))
175
+
176
+ def get_captions(self, ix, seq_per_img):
177
+ # fetch the sequence labels
178
+ ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1
179
+ ix2 = self.label_end_ix[ix] - 1
180
+ ncap = ix2 - ix1 + 1 # number of captions available for this image
181
+ assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'
182
+
183
+ if ncap < seq_per_img:
184
+ # we need to subsample (with replacement)
185
+ seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int')
186
+ for q in range(seq_per_img):
187
+ ixl = random.randint(ix1,ix2)
188
+ seq[q, :] = self.label[ixl, :self.seq_length]
189
+ else:
190
+ ixl = random.randint(ix1, ix2 - seq_per_img + 1)
191
+ seq = self.label[ixl: ixl + seq_per_img, :self.seq_length]
192
+
193
+ return seq
194
+
195
+ def collate_func(self, batch):
196
+ seq_per_img = self.seq_per_img
197
+
198
+ fc_batch = []
199
+ att_batch = []
200
+ label_batch = []
201
+
202
+ clip_vis_feat_batch = []
203
+
204
+ wrapped = False
205
+
206
+ infos = []
207
+ gts = []
208
+
209
+ for sample in batch:
210
+ # fetch image
211
+ # if self.use_clipscore:
212
+ tmp_fc, tmp_att, tmp_seq, \
213
+ ix, tmp_clip_vis_feat = sample
214
+
215
+ clip_vis_feat_batch.append(tmp_clip_vis_feat)
216
+ # else:
217
+ # tmp_fc, tmp_att, tmp_seq, \
218
+ # ix = sample
219
+
220
+ fc_batch.append(tmp_fc)
221
+ att_batch.append(tmp_att)
222
+
223
+ tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int')
224
+ if hasattr(self, 'h5_label_file'):
225
+ # if there is ground truth
226
+ tmp_label[:, 1 : self.seq_length + 1] = tmp_seq
227
+ label_batch.append(tmp_label)
228
+
229
+ # Used for reward evaluation
230
+ if hasattr(self, 'h5_label_file'):
231
+ # if there is ground truth
232
+ gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]])
233
+ else:
234
+ gts.append([])
235
+
236
+ # record associated info as well
237
+ info_dict = {}
238
+ info_dict['ix'] = ix
239
+ info_dict['id'] = self.info['images'][ix]['id']
240
+ info_dict['file_path'] = self.info['images'][ix].get('file_path', '')
241
+ infos.append(info_dict)
242
+
243
+ # #sort by att_feat length
244
+ # fc_batch, att_batch, label_batch, gts, infos = \
245
+ # zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
246
+ if self.use_clipscore:
247
+ fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos = \
248
+ zip(*sorted(zip(fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos), key=lambda x: 0, reverse=True))
249
+ else:
250
+ fc_batch, att_batch, label_batch, gts, infos = \
251
+ zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True))
252
+ data = {}
253
+ data['fc_feats'] = np.stack(fc_batch)
254
+ # merge att_feats
255
+ max_att_len = max([_.shape[0] for _ in att_batch])
256
+ data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype = 'float32')
257
+ for i in range(len(att_batch)):
258
+ data['att_feats'][i, :att_batch[i].shape[0]] = att_batch[i]
259
+ data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32')
260
+ for i in range(len(att_batch)):
261
+ data['att_masks'][i, :att_batch[i].shape[0]] = 1
262
+ # set att_masks to None if attention features have same length
263
+ if data['att_masks'].sum() == data['att_masks'].size:
264
+ data['att_masks'] = None
265
+
266
+ # if self.use_clipscore:
267
+ data['clip_vis_feats'] = np.stack(clip_vis_feat_batch)
268
+
269
+ data['labels'] = np.vstack(label_batch)
270
+ # generate mask
271
+ nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels'])))
272
+ mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32')
273
+ for ix, row in enumerate(mask_batch):
274
+ row[:nonzeros[ix]] = 1
275
+ data['masks'] = mask_batch
276
+ data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1)
277
+ data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1)
278
+
279
+ data['gts'] = gts # all ground truth captions for each image
280
+ data['infos'] = infos
281
+
282
+ data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor
283
+
284
+ return data
285
+
286
+ def __getitem__(self, ix):
287
+ """This function returns a tuple that is further passed to collate_fn
288
+ """
289
+ if self.use_att:
290
+ att_feat = self.att_loader.get(str(self.info['images'][ix]['id']))
291
+ # Reshape to K x C
292
+ att_feat = att_feat.reshape(-1, att_feat.shape[-1])
293
+ if self.norm_att_feat:
294
+ att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True)
295
+ if self.use_box:
296
+ box_feat = self.box_loader.get(str(self.info['images'][ix]['id']))
297
+ # divided by image width and height
298
+ x1,y1,x2,y2 = np.hsplit(box_feat, 4)
299
+ h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width']
300
+ box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1??
301
+ if self.norm_box_feat:
302
+ box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True)
303
+ att_feat = np.hstack([att_feat, box_feat])
304
+ # sort the features by the size of boxes
305
+ att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True))
306
+ else:
307
+ att_feat = np.zeros((0,0), dtype='float32')
308
+ if self.use_fc:
309
+ try:
310
+ fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id']))
311
+ except:
312
+ # Use average of attention when there is no fc provided (For bottomup feature)
313
+ fc_feat = att_feat.mean(0)
314
+ else:
315
+ fc_feat = np.zeros((0), dtype='float32')
316
+ if hasattr(self, 'h5_label_file'):
317
+ seq = self.get_captions(ix, self.seq_per_img)
318
+ else:
319
+ seq = None
320
+
321
+ # if self.use_clipscore:
322
+ clip_vis_feat = self.clipscore_loader.get(
323
+ str(self.info['images'][ix]['id']))
324
+
325
+ return (fc_feat,
326
+ att_feat, seq,
327
+ ix, clip_vis_feat)
328
+
329
+ # return (fc_feat,
330
+ # att_feat, seq,
331
+ # ix)
332
+
333
+ def __len__(self):
334
+ return len(self.info['images'])
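pth_loader.py mirrors dataloader.py but exposes a plain CaptionDataset, indexed by integer image index and returning CLIP visual features alongside the usual fc/att features, leaving batching to a standard torch DataLoader. A minimal wiring sketch, assuming the same `opt` namespace (including opt.input_clipscore_vis_dir) and drawing training indices from dataset.split_ix['train']; the sampler choice here is illustrative rather than the one used by the training scripts:

    import torch.utils.data as data
    from captioning.data.pth_loader import CaptionDataset

    dataset = CaptionDataset(opt)
    train_loader = data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        sampler=data.SubsetRandomSampler(dataset.split_ix['train']),
        collate_fn=dataset.collate_func,        # pads att_feats, builds masks, stacks clip_vis_feats
        num_workers=4,
        pin_memory=True)

    batch = next(iter(train_loader))
    print(batch['fc_feats'].shape, batch['att_feats'].shape, batch['clip_vis_feats'].shape)
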
captioning/data/pth_loader_FineCapEval.py ADDED
@@ -0,0 +1,334 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import json
6
+ import h5py
7
+ from lmdbdict import lmdbdict
8
+ from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC
9
+ import os
10
+ import numpy as np
11
+ import numpy.random as npr
12
+ import random
13
+
14
+ import torch
15
+ import torch.utils.data as data
16
+
17
+ import multiprocessing
18
+ import six
19
+
20
+ verbose = True
21
+ # import torch
22
+ # if torch.cuda.current_device() in [0, -1]:
23
+ if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0':
24
+ verbose = False
25
+
26
+ class HybridLoader:
27
+ """
28
+ If db_path is a directory, then use normal file loading
29
+ If lmdb, then load from lmdb
30
+ The loading method depends on the extension.
31
+
32
+ in_memory: if in_memory is True, we save all the features in memory
33
+ For individual np(y|z)s, we don't need to do that because the system will do this for us.
34
+ Should be useful for lmdb or h5.
35
+ (Copied this idea from vilbert)
36
+ """
37
+ def __init__(self, db_path, ext, in_memory=False):
38
+ self.db_path = db_path
39
+ self.ext = ext
40
+ if self.ext == '.npy':
41
+ self.loader = lambda x: np.load(six.BytesIO(x))
42
+ else:
43
+ self.loader = lambda x: np.load(six.BytesIO(x))['feat']
44
+ if db_path.endswith('.lmdb'):
45
+ self.db_type = 'lmdb'
46
+ self.lmdb = lmdbdict(db_path, unsafe=True)
47
+ self.lmdb._key_dumps = DUMPS_FUNC['ascii']
48
+ self.lmdb._value_loads = LOADS_FUNC['identity']
49
+ elif db_path.endswith('.pth'): # Assume a key,value dictionary
50
+ self.db_type = 'pth'
51
+ self.feat_file = torch.load(db_path)
52
+ self.loader = lambda x: x
53
+ print('HybridLoader: ext is ignored')
54
+ elif db_path.endswith('h5'):
55
+ self.db_type = 'h5'
56
+ self.loader = lambda x: np.array(x).astype('float32')
57
+ else:
58
+ self.db_type = 'dir'
59
+
60
+ self.in_memory = in_memory
61
+ if self.in_memory:
62
+ self.features = {}
63
+
64
+ def get(self, key):
65
+
66
+ if self.in_memory and key in self.features:
67
+ # We save f_input because we want to save the
68
+ # compressed bytes to save memory
69
+ f_input = self.features[key]
70
+ elif self.db_type == 'lmdb':
71
+ f_input = self.lmdb[key]
72
+ elif self.db_type == 'pth':
73
+ f_input = self.feat_file[key]
74
+ elif self.db_type == 'h5':
75
+ f_input = h5py.File(self.db_path, 'r')[key]
76
+ else:
77
+ f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read()
78
+
79
+ if self.in_memory and key not in self.features:
80
+ self.features[key] = f_input
81
+
82
+ # load image
83
+ feat = self.loader(f_input)
84
+
85
+ return feat
86
+
87
+ class CaptionDataset(data.Dataset):
88
+
89
+ def get_vocab_size(self):
90
+ return self.vocab_size
91
+
92
+ def get_vocab(self):
93
+ return self.ix_to_word
94
+
95
+ def get_seq_length(self):
96
+ return self.seq_length
97
+
98
+ def __init__(self, opt):
99
+ self.opt = opt
100
+ self.seq_per_img = opt.seq_per_img
101
+
102
+ # feature related options
103
+ self.use_fc = getattr(opt, 'use_fc', True)
104
+ self.use_att = getattr(opt, 'use_att', True)
105
+ self.use_box = getattr(opt, 'use_box', 0)
106
+ self.norm_att_feat = getattr(opt, 'norm_att_feat', 0)
107
+ self.norm_box_feat = getattr(opt, 'norm_box_feat', 0)
108
+
109
+ # load the json file which contains additional information about the dataset
110
+ if verbose:
111
+ print('DataLoader loading json file: ', opt.input_json)
112
+ self.info = json.load(open(self.opt.input_json))
113
+ if 'ix_to_word' in self.info:
114
+ self.ix_to_word = self.info['ix_to_word']
115
+ self.vocab_size = len(self.ix_to_word)
116
+ if verbose:
117
+ print('vocab size is ', self.vocab_size)
118
+
119
+ # open the hdf5 file
120
+ if verbose:
121
+ print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5)
122
+ """
123
+ Setting input_label_h5 to none is used when only doing generation.
124
+ For example, when you need to test on coco test set.
125
+ """
126
+ if self.opt.input_label_h5 != 'none':
127
+ self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core')
128
+ # load in the sequence data
129
+ seq_size = self.h5_label_file['labels'].shape
130
+ self.label = self.h5_label_file['labels'][:]
131
+ self.seq_length = seq_size[1]
132
+ if verbose:
133
+ print('max sequence length in data is', self.seq_length)
134
+ # load the pointers in full to RAM (should be small enough)
135
+ self.label_start_ix = self.h5_label_file['label_start_ix'][:]
136
+ self.label_end_ix = self.h5_label_file['label_end_ix'][:]
137
+ else:
138
+ self.seq_length = 1
139
+
140
+ self.data_in_memory = getattr(opt, 'data_in_memory', False)
141
+ self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory)
142
+ self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory)
143
+ self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory)
144
+
145
+ self.use_clipscore = getattr(opt, 'use_clipscore', False)
146
+ if self.use_clipscore:
147
+ self.clipscore_loader = HybridLoader(self.opt.input_clipscore_vis_dir, '.npy', in_memory=self.data_in_memory)
148
+
149
+
150
+ self.num_images = len(self.info['images']) # self.label_start_ix.shape[0]
151
+ if verbose:
152
+ print('read %d image features' %(self.num_images))
153
+
154
+ # separate out indexes for each of the provided splits
155
+ self.split_ix = {'train': [], 'val': [], 'test': []}
156
+ for ix in range(len(self.info['images'])):
157
+ img = self.info['images'][ix]
158
+ if not 'split' in img:
159
+ self.split_ix['train'].append(ix)
160
+ self.split_ix['val'].append(ix)
161
+ self.split_ix['test'].append(ix)
162
+ elif img['split'] == 'train':
163
+ self.split_ix['train'].append(ix)
164
+ elif img['split'] == 'val':
165
+ self.split_ix['val'].append(ix)
166
+ elif img['split'] == 'test':
167
+ self.split_ix['test'].append(ix)
168
+ elif opt.train_only == 0: # restval
169
+ self.split_ix['train'].append(ix)
170
+
171
+ if verbose:
172
+ print('assigned %d images to split train' %len(self.split_ix['train']))
173
+ print('assigned %d images to split val' %len(self.split_ix['val']))
174
+ print('assigned %d images to split test' %len(self.split_ix['test']))
175
+
176
+ def get_captions(self, ix, seq_per_img):
177
+ # fetch the sequence labels
178
+ ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1
179
+ ix2 = self.label_end_ix[ix] - 1
180
+ ncap = ix2 - ix1 + 1 # number of captions available for this image
181
+ assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'
182
+
183
+ if ncap < seq_per_img:
184
+ # we need to subsample (with replacement)
185
+ seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int')
186
+ for q in range(seq_per_img):
187
+ ixl = random.randint(ix1,ix2)
188
+ seq[q, :] = self.label[ixl, :self.seq_length]
189
+ else:
190
+ ixl = random.randint(ix1, ix2 - seq_per_img + 1)
191
+ seq = self.label[ixl: ixl + seq_per_img, :self.seq_length]
192
+
193
+ return seq
194
+
195
+ def collate_func(self, batch):
196
+ seq_per_img = self.seq_per_img
197
+
198
+ fc_batch = []
199
+ att_batch = []
200
+ label_batch = []
201
+
202
+ clip_vis_feat_batch = []
203
+
204
+ wrapped = False
205
+
206
+ infos = []
207
+ gts = []
208
+
209
+ for sample in batch:
210
+ # fetch image
211
+ if self.use_clipscore:
212
+ tmp_fc, tmp_att, tmp_seq, \
213
+ ix, tmp_clip_vis_feat = sample
214
+
215
+ clip_vis_feat_batch.append(tmp_clip_vis_feat)
216
+ else:
217
+ tmp_fc, tmp_att, tmp_seq, \
218
+ ix = sample
219
+
220
+ fc_batch.append(tmp_fc)
221
+ att_batch.append(tmp_att)
222
+
223
+ tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int')
224
+ if hasattr(self, 'h5_label_file'):
225
+ # if there is ground truth
226
+ tmp_label[:, 1 : self.seq_length + 1] = tmp_seq
227
+ label_batch.append(tmp_label)
228
+
229
+ # Used for reward evaluation
230
+ if hasattr(self, 'h5_label_file'):
231
+ # if there is ground truth
232
+ gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]])
233
+ else:
234
+ gts.append([])
235
+
236
+ # record associated info as well
237
+ info_dict = {}
238
+ info_dict['ix'] = ix
239
+ info_dict['id'] = self.info['images'][ix]['id']
240
+ info_dict['file_path'] = self.info['images'][ix].get('file_path', '')
241
+ infos.append(info_dict)
242
+
243
+ # #sort by att_feat length
244
+ # fc_batch, att_batch, label_batch, gts, infos = \
245
+ # zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
246
+ if self.use_clipscore:
247
+ fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos = \
248
+ zip(*sorted(zip(fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos), key=lambda x: 0, reverse=True))
249
+ else:
250
+ fc_batch, att_batch, label_batch, gts, infos = \
251
+ zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True))
252
+ data = {}
253
+ data['fc_feats'] = np.stack(fc_batch)
254
+ # merge att_feats
255
+ max_att_len = max([_.shape[0] for _ in att_batch])
256
+ data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype = 'float32')
257
+ for i in range(len(att_batch)):
258
+ data['att_feats'][i, :att_batch[i].shape[0]] = att_batch[i]
259
+ data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32')
260
+ for i in range(len(att_batch)):
261
+ data['att_masks'][i, :att_batch[i].shape[0]] = 1
262
+ # set att_masks to None if attention features have same length
263
+ if data['att_masks'].sum() == data['att_masks'].size:
264
+ data['att_masks'] = None
265
+
266
+ if self.use_clipscore:
267
+ data['clip_vis_feats'] = np.stack(clip_vis_feat_batch)
268
+
269
+ data['labels'] = np.vstack(label_batch)
270
+ # generate mask
271
+ nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels'])))
272
+ mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32')
273
+ for ix, row in enumerate(mask_batch):
274
+ row[:nonzeros[ix]] = 1
275
+ data['masks'] = mask_batch
276
+ data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1)
277
+ data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1)
278
+
279
+ data['gts'] = gts # all ground truth captions of each image
280
+ data['infos'] = infos
281
+
282
+ data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor
283
+
284
+ return data
285
+
286
+ def __getitem__(self, ix):
287
+ """This function returns a tuple that is further passed to collate_fn
288
+ """
289
+ if self.use_att:
290
+ att_feat = self.att_loader.get(str(self.info['images'][ix]['id']))
291
+ # Reshape to K x C
292
+ att_feat = att_feat.reshape(-1, att_feat.shape[-1])
293
+ if self.norm_att_feat:
294
+ att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True)
295
+ if self.use_box:
296
+ box_feat = self.box_loader.get(str(self.info['images'][ix]['id']))
297
+ # divided by image width and height
298
+ x1,y1,x2,y2 = np.hsplit(box_feat, 4)
299
+ h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width']
300
+ box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1??
301
+ if self.norm_box_feat:
302
+ box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True)
303
+ att_feat = np.hstack([att_feat, box_feat])
304
+ # sort the features by the size of boxes
305
+ att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True))
306
+ else:
307
+ att_feat = np.zeros((0,0), dtype='float32')
308
+ if self.use_fc:
309
+ try:
310
+ fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id']))
311
+ except:
312
+ # Use the average of the attention features when no fc feature is provided (for bottom-up features)
313
+ fc_feat = att_feat.mean(0)
314
+ else:
315
+ fc_feat = np.zeros((0), dtype='float32')
316
+ if hasattr(self, 'h5_label_file'):
317
+ seq = self.get_captions(ix, self.seq_per_img)
318
+ else:
319
+ seq = None
320
+
321
+ if self.use_clipscore:
322
+ clip_vis_feat = self.clipscore_loader.get(
323
+ str(self.info['images'][ix]['id']))
324
+
325
+ return (fc_feat,
326
+ att_feat, seq,
327
+ ix, clip_vis_feat)
328
+
329
+ return (fc_feat,
330
+ att_feat, seq,
331
+ ix)
332
+
333
+ def __len__(self):
334
+ return len(self.info['images'])
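As a reading aid: HybridLoader dispatches on the feature store's path (plain directory, .lmdb, .pth, or h5), and CaptionDataset.collate_func is intended to be passed to a PyTorch DataLoader as collate_fn. A minimal sketch of how the pieces might be wired together; the opt fields and paths are illustrative assumptions, not values from this commit:

import torch
from types import SimpleNamespace
from captioning.data.pth_loader_FineCapEval import CaptionDataset  # the file added above

opt = SimpleNamespace(
    seq_per_img=5,
    input_json='data/cocotalk.json',      # hypothetical preprocessed metadata
    input_label_h5='none',                # generation-only mode, per the docstring
    input_fc_dir='data/cocotalk_fc',
    input_att_dir='data/cocotalk_att',
    input_box_dir='data/cocotalk_box',
    use_clipscore=False,
    train_only=0,
)

dataset = CaptionDataset(opt)
loader = torch.utils.data.DataLoader(
    dataset, batch_size=16, shuffle=False,
    collate_fn=dataset.collate_func)      # yields dicts of padded tensors

batch = next(iter(loader))
print(batch['att_feats'].shape, batch['labels'].shape)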
captioning/models/AoAModel.py ADDED
@@ -0,0 +1,228 @@
1
+ # Implementation for paper 'Attention on Attention for Image Captioning'
2
+ # https://arxiv.org/abs/1908.06954
3
+
4
+ # RT: Code from original author's repo: https://github.com/husthuaan/AoANet/
5
+
6
+ from __future__ import absolute_import
7
+ from __future__ import division
8
+ from __future__ import print_function
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from .AttModel import pack_wrapper, AttModel, Attention
15
+ from .TransformerModel import LayerNorm, attention, clones, SublayerConnection, PositionwiseFeedForward
16
+
17
+ class MultiHeadedDotAttention(nn.Module):
18
+ def __init__(self, h, d_model, dropout=0.1, scale=1, project_k_v=1, use_output_layer=1, do_aoa=0, norm_q=0, dropout_aoa=0.3):
19
+ super(MultiHeadedDotAttention, self).__init__()
20
+ assert d_model * scale % h == 0
21
+ # We assume d_v always equals d_k
22
+ self.d_k = d_model * scale // h
23
+ self.h = h
24
+
25
+ # Do we need to do linear projections on K and V?
26
+ self.project_k_v = project_k_v
27
+
28
+ # normalize the query?
29
+ if norm_q:
30
+ self.norm = LayerNorm(d_model)
31
+ else:
32
+ self.norm = lambda x:x
33
+ self.linears = clones(nn.Linear(d_model, d_model * scale), 1 + 2 * project_k_v)
34
+
35
+ # output linear layer after the multi-head attention?
36
+ self.output_layer = nn.Linear(d_model * scale, d_model)
37
+
38
+ # apply aoa after attention?
39
+ self.use_aoa = do_aoa
40
+ if self.use_aoa:
41
+ self.aoa_layer = nn.Sequential(nn.Linear((1 + scale) * d_model, 2 * d_model), nn.GLU())
42
+ # dropout to the input of AoA layer
43
+ if dropout_aoa > 0:
44
+ self.dropout_aoa = nn.Dropout(p=dropout_aoa)
45
+ else:
46
+ self.dropout_aoa = lambda x:x
47
+
48
+ if self.use_aoa or not use_output_layer:
49
+ # AoA doesn't need the output linear layer
50
+ del self.output_layer
51
+ self.output_layer = lambda x:x
52
+
53
+ self.attn = None
54
+ self.dropout = nn.Dropout(p=dropout)
55
+
56
+ def forward(self, query, value, key, mask=None):
57
+ if mask is not None:
58
+ if len(mask.size()) == 2:
59
+ mask = mask.unsqueeze(-2)
60
+ # Same mask applied to all h heads.
61
+ mask = mask.unsqueeze(1)
62
+
63
+ single_query = 0
64
+ if len(query.size()) == 2:
65
+ single_query = 1
66
+ query = query.unsqueeze(1)
67
+
68
+ nbatches = query.size(0)
69
+
70
+ query = self.norm(query)
71
+
72
+ # Do all the linear projections in batch from d_model => h x d_k
73
+ if self.project_k_v == 0:
74
+ query_ = self.linears[0](query).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
75
+ key_ = key.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
76
+ value_ = value.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
77
+ else:
78
+ query_, key_, value_ = \
79
+ [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
80
+ for l, x in zip(self.linears, (query, key, value))]
81
+
82
+ # Apply attention on all the projected vectors in batch.
83
+ x, self.attn = attention(query_, key_, value_, mask=mask,
84
+ dropout=self.dropout)
85
+
86
+ # "Concat" using a view
87
+ x = x.transpose(1, 2).contiguous() \
88
+ .view(nbatches, -1, self.h * self.d_k)
89
+
90
+ if self.use_aoa:
91
+ # Apply AoA
92
+ x = self.aoa_layer(self.dropout_aoa(torch.cat([x, query], -1)))
93
+ x = self.output_layer(x)
94
+
95
+ if single_query:
96
+ query = query.squeeze(1)
97
+ x = x.squeeze(1)
98
+ return x
99
+
100
+ class AoA_Refiner_Layer(nn.Module):
101
+ def __init__(self, size, self_attn, feed_forward, dropout):
102
+ super(AoA_Refiner_Layer, self).__init__()
103
+ self.self_attn = self_attn
104
+ self.feed_forward = feed_forward
105
+ self.use_ff = 0
106
+ if self.feed_forward is not None:
107
+ self.use_ff = 1
108
+ self.sublayer = clones(SublayerConnection(size, dropout), 1+self.use_ff)
109
+ self.size = size
110
+
111
+ def forward(self, x, mask):
112
+ x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
113
+ return self.sublayer[-1](x, self.feed_forward) if self.use_ff else x
114
+
115
+ class AoA_Refiner_Core(nn.Module):
116
+ def __init__(self, opt):
117
+ super(AoA_Refiner_Core, self).__init__()
118
+ attn = MultiHeadedDotAttention(opt.num_heads, opt.rnn_size, project_k_v=1, scale=opt.multi_head_scale, do_aoa=opt.refine_aoa, norm_q=0, dropout_aoa=getattr(opt, 'dropout_aoa', 0.3))
119
+ layer = AoA_Refiner_Layer(opt.rnn_size, attn, PositionwiseFeedForward(opt.rnn_size, 2048, 0.1) if opt.use_ff else None, 0.1)
120
+ self.layers = clones(layer, 6)
121
+ self.norm = LayerNorm(layer.size)
122
+
123
+ def forward(self, x, mask):
124
+ for layer in self.layers:
125
+ x = layer(x, mask)
126
+ return self.norm(x)
127
+
128
+ class AoA_Decoder_Core(nn.Module):
129
+ def __init__(self, opt):
130
+ super(AoA_Decoder_Core, self).__init__()
131
+ self.drop_prob_lm = opt.drop_prob_lm
132
+ self.d_model = opt.rnn_size
133
+ self.use_multi_head = opt.use_multi_head
134
+ self.multi_head_scale = opt.multi_head_scale
135
+ self.use_ctx_drop = getattr(opt, 'ctx_drop', 0)
136
+ self.out_res = getattr(opt, 'out_res', 0)
137
+ self.decoder_type = getattr(opt, 'decoder_type', 'AoA')
138
+ self.att_lstm = nn.LSTMCell(opt.input_encoding_size + opt.rnn_size, opt.rnn_size) # we, fc, h^2_t-1
139
+ self.out_drop = nn.Dropout(self.drop_prob_lm)
140
+
141
+ if self.decoder_type == 'AoA':
142
+ # AoA layer
143
+ self.att2ctx = nn.Sequential(nn.Linear(self.d_model * opt.multi_head_scale + opt.rnn_size, 2 * opt.rnn_size), nn.GLU())
144
+ elif self.decoder_type == 'LSTM':
145
+ # LSTM layer
146
+ self.att2ctx = nn.LSTMCell(self.d_model * opt.multi_head_scale + opt.rnn_size, opt.rnn_size)
147
+ else:
148
+ # Base linear layer
149
+ self.att2ctx = nn.Sequential(nn.Linear(self.d_model * opt.multi_head_scale + opt.rnn_size, opt.rnn_size), nn.ReLU())
150
+
151
+ # if opt.use_multi_head == 1: # TODO, not implemented for now
152
+ # self.attention = MultiHeadedAddAttention(opt.num_heads, opt.d_model, scale=opt.multi_head_scale)
153
+ if opt.use_multi_head == 2:
154
+ self.attention = MultiHeadedDotAttention(opt.num_heads, opt.rnn_size, project_k_v=0, scale=opt.multi_head_scale, use_output_layer=0, do_aoa=0, norm_q=1)
155
+ else:
156
+ self.attention = Attention(opt)
157
+
158
+ if self.use_ctx_drop:
159
+ self.ctx_drop = nn.Dropout(self.drop_prob_lm)
160
+ else:
161
+ self.ctx_drop = lambda x :x
162
+
163
+ def forward(self, xt, mean_feats, att_feats, p_att_feats, state, att_masks=None):
164
+ # state[0][1] is the context vector at the last step
165
+ h_att, c_att = self.att_lstm(torch.cat([xt, mean_feats + self.ctx_drop(state[0][1])], 1), (state[0][0], state[1][0]))
166
+
167
+ if self.use_multi_head == 2:
168
+ att = self.attention(h_att, p_att_feats.narrow(2, 0, self.multi_head_scale * self.d_model), p_att_feats.narrow(2, self.multi_head_scale * self.d_model, self.multi_head_scale * self.d_model), att_masks)
169
+ else:
170
+ att = self.attention(h_att, att_feats, p_att_feats, att_masks)
171
+
172
+ ctx_input = torch.cat([att, h_att], 1)
173
+ if self.decoder_type == 'LSTM':
174
+ output, c_logic = self.att2ctx(ctx_input, (state[0][1], state[1][1]))
175
+ state = (torch.stack((h_att, output)), torch.stack((c_att, c_logic)))
176
+ else:
177
+ output = self.att2ctx(ctx_input)
178
+ # save the context vector to state[0][1]
179
+ state = (torch.stack((h_att, output)), torch.stack((c_att, state[1][1])))
180
+
181
+ if self.out_res:
182
+ # add residual connection
183
+ output = output + h_att
184
+
185
+ output = self.out_drop(output)
186
+ return output, state
187
+
188
+ class AoAModel(AttModel):
189
+ def __init__(self, opt):
190
+ super(AoAModel, self).__init__(opt)
191
+ self.num_layers = 2
192
+ # mean pooling
193
+ self.use_mean_feats = getattr(opt, 'mean_feats', 1)
194
+ if opt.use_multi_head == 2:
195
+ del self.ctx2att
196
+ self.ctx2att = nn.Linear(opt.rnn_size, 2 * opt.multi_head_scale * opt.rnn_size)
197
+
198
+ if self.use_mean_feats:
199
+ del self.fc_embed
200
+ if opt.refine:
201
+ self.refiner = AoA_Refiner_Core(opt)
202
+ else:
203
+ self.refiner = lambda x,y : x
204
+ self.core = AoA_Decoder_Core(opt)
205
+
206
+ self.d_model = getattr(opt, 'd_model', opt.input_encoding_size)
207
+
208
+
209
+ def _prepare_feature(self, fc_feats, att_feats, att_masks):
210
+ att_feats, att_masks = self.clip_att(att_feats, att_masks)
211
+
212
+ # embed att feats
213
+ att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
214
+ att_feats = self.refiner(att_feats, att_masks)
215
+
216
+ if self.use_mean_feats:
217
+ # mean pooling
218
+ if att_masks is None:
219
+ mean_feats = torch.mean(att_feats, dim=1)
220
+ else:
221
+ mean_feats = (torch.sum(att_feats * att_masks.unsqueeze(-1), 1) / torch.sum(att_masks.unsqueeze(-1), 1))
222
+ else:
223
+ mean_feats = self.fc_embed(fc_feats)
224
+
225
+ # Project the attention feats first to reduce memory and computation.
226
+ p_att_feats = self.ctx2att(att_feats)
227
+
228
+ return mean_feats, att_feats, p_att_feats, att_masks
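For orientation, the "attention on attention" step used throughout this file is a linear layer followed by a GLU over the concatenation of the attended vector and the query, i.e. an information vector modulated by a learned sigmoid gate. A standalone sketch of that operation with the scale factor fixed to 1 (dimensions are illustrative):

import torch
import torch.nn as nn

d_model = 512                                  # illustrative model width
aoa = nn.Sequential(nn.Linear(2 * d_model, 2 * d_model), nn.GLU())

att_res = torch.randn(4, d_model)              # attended vector
query = torch.randn(4, d_model)                # decoder query

# GLU splits the linear output into (information, gate) halves and
# returns information * sigmoid(gate), as in self.aoa_layer above.
out = aoa(torch.cat([att_res, query], dim=-1))
print(out.shape)                               # torch.Size([4, 512])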
captioning/models/AttEnsemble.py ADDED
@@ -0,0 +1,90 @@
1
+ # This file is the implementation for ensemble evaluation.
2
+
3
+ from __future__ import absolute_import
4
+ from __future__ import division
5
+ from __future__ import print_function
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from torch.autograd import *
12
+
13
+ from .CaptionModel import CaptionModel
14
+ from .AttModel import pack_wrapper, AttModel
15
+
16
+ class AttEnsemble(AttModel):
17
+ def __init__(self, models, weights=None):
18
+ CaptionModel.__init__(self)
19
+ # super(AttEnsemble, self).__init__()
20
+
21
+ self.models = nn.ModuleList(models)
22
+ self.vocab_size = models[0].vocab_size
23
+ self.seq_length = models[0].seq_length
24
+ self.bad_endings_ix = models[0].bad_endings_ix
25
+ self.ss_prob = 0
26
+ weights = weights or [1.0] * len(self.models)
27
+ self.register_buffer('weights', torch.tensor(weights))
28
+
29
+ def init_hidden(self, batch_size):
30
+ state = [m.init_hidden(batch_size) for m in self.models]
31
+ return self.pack_state(state)
32
+
33
+ def pack_state(self, state):
34
+ self.state_lengths = [len(_) for _ in state]
35
+ return sum([list(_) for _ in state], [])
36
+
37
+ def unpack_state(self, state):
38
+ out = []
39
+ for l in self.state_lengths:
40
+ out.append(state[:l])
41
+ state = state[l:]
42
+ return out
43
+
44
+ def embed(self, it):
45
+ return [m.embed(it) for m in self.models]
46
+
47
+ def core(self, *args):
48
+ return zip(*[m.core(*_) for m, _ in zip(self.models, zip(*args))])
49
+
50
+ def get_logprobs_state(self, it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state, output_logsoftmax=1):
51
+ # 'it' contains a word index
52
+ xt = self.embed(it)
53
+
54
+ state = self.unpack_state(state)
55
+ output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state, tmp_att_masks)
56
+ logprobs = torch.stack([F.softmax(m.logit(output[i]), dim=1) for i,m in enumerate(self.models)], 2).mul(self.weights).div(self.weights.sum()).sum(-1).log()
57
+
58
+ return logprobs, self.pack_state(state)
59
+
60
+ def _prepare_feature(self, *args):
61
+ return tuple(zip(*[m._prepare_feature(*args) for m in self.models]))
62
+
63
+ def _old_sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
64
+ beam_size = opt.get('beam_size', 10)
65
+ batch_size = fc_feats.size(0)
66
+
67
+ fc_feats, att_feats, p_att_feats, att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
68
+
69
+ assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
70
+ seq = torch.LongTensor(self.seq_length, batch_size).zero_()
71
+ seqLogprobs = torch.FloatTensor(self.seq_length, batch_size, self.vocab_size + 1)
72
+ # lets process every image independently for now, for simplicity
73
+
74
+ self.done_beams = [[] for _ in range(batch_size)]
75
+ for k in range(batch_size):
76
+ state = self.init_hidden(beam_size)
77
+ tmp_fc_feats = [fc_feats[i][k:k+1].expand(beam_size, fc_feats[i].size(1)) for i,m in enumerate(self.models)]
78
+ tmp_att_feats = [att_feats[i][k:k+1].expand(*((beam_size,)+att_feats[i].size()[1:])).contiguous() for i,m in enumerate(self.models)]
79
+ tmp_p_att_feats = [p_att_feats[i][k:k+1].expand(*((beam_size,)+p_att_feats[i].size()[1:])).contiguous() for i,m in enumerate(self.models)]
80
+ tmp_att_masks = [att_masks[i][k:k+1].expand(*((beam_size,)+att_masks[i].size()[1:])).contiguous() if att_masks[i] is not None else att_masks[i] for i,m in enumerate(self.models)]
81
+
82
+ it = fc_feats[0].data.new(beam_size).long().zero_()
83
+ logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state)
84
+
85
+ self.done_beams[k] = self.old_beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt)
86
+ seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
87
+ seqLogprobs[:, k] = self.done_beams[k][0]['logps']
88
+ # return the samples and their log likelihoods
89
+ return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
90
+ # return the samples and their log likelihoods
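The core of get_logprobs_state above is the ensemble reduction: each model's logits are softmaxed, stacked along a trailing model axis, averaged with the registered weights, and converted back to log-probabilities. A small sketch of just that reduction (shapes are illustrative):

import torch

vocab_size = 7
weights = torch.tensor([1.0, 2.0])                       # per-model ensemble weights
logits = [torch.randn(3, vocab_size) for _ in range(2)]  # one logit tensor per model

probs = torch.stack([l.softmax(dim=1) for l in logits], dim=2)  # (batch, vocab, n_models)
logprobs = probs.mul(weights).div(weights.sum()).sum(-1).log()  # (batch, vocab)
print(logprobs.shape)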
captioning/models/AttModel.py ADDED
@@ -0,0 +1,969 @@
1
+ # This file contains the Att2in2, AdaAtt, AdaAttMO, and UpDown models
2
+
3
+ # AdaAtt is from Knowing When to Look: Adaptive Attention via A Visual Sentinel for Image Captioning
4
+ # https://arxiv.org/abs/1612.01887
5
+ # AdaAttMO is a modified version with maxout lstm
6
+
7
+ # Att2in is from Self-critical Sequence Training for Image Captioning
8
+ # https://arxiv.org/abs/1612.00563
9
+ # In this file we only have Att2in2, which is a slightly different version of att2in,
10
+ # in which the image feature embedding and the word embedding are the same as in AdaAtt.
11
+
12
+ # UpDown is from Bottom-Up and Top-Down Attention for Image Captioning and VQA
13
+ # https://arxiv.org/abs/1707.07998
14
+ # However, it may not be identical to the author's architecture.
15
+
16
+ from __future__ import absolute_import
17
+ from __future__ import division
18
+ from __future__ import print_function
19
+
20
+ import numpy as np
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from . import utils
25
+ from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
26
+
27
+ from .CaptionModel import CaptionModel
28
+
29
+ bad_endings = ['a','an','the','in','for','at','of','with','before','after','on','upon','near','to','is','are','am']
30
+ bad_endings += ['the']
31
+
32
+ def sort_pack_padded_sequence(input, lengths):
33
+ sorted_lengths, indices = torch.sort(lengths, descending=True)
34
+ # tmp = pack_padded_sequence(input[indices], sorted_lengths, batch_first=True)
35
+ tmp = pack_padded_sequence(input[indices], sorted_lengths.cpu(), batch_first=True)
36
+ inv_ix = indices.clone()
37
+ inv_ix[indices] = torch.arange(0,len(indices)).type_as(inv_ix)
38
+ return tmp, inv_ix
39
+
40
+ def pad_unsort_packed_sequence(input, inv_ix):
41
+ tmp, _ = pad_packed_sequence(input, batch_first=True)
42
+ tmp = tmp[inv_ix]
43
+ return tmp
44
+
45
+ def pack_wrapper(module, att_feats, att_masks):
46
+ if att_masks is not None:
47
+ packed, inv_ix = sort_pack_padded_sequence(att_feats, att_masks.data.long().sum(1))
48
+ return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix)
49
+ else:
50
+ return module(att_feats)
51
+
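A brief illustration of the two helpers above: pack_wrapper sorts the batch by mask length, applies the module only to the packed (non-padded) positions, and restores the original order and padding. The shapes and the stand-in module below are assumptions for the example:

import torch
import torch.nn as nn
from captioning.models.AttModel import pack_wrapper  # helper defined in this file

# Padded region features: batch of 2, at most 3 regions, 5-dim features.
att_feats = torch.randn(2, 3, 5)
att_masks = torch.tensor([[1., 1., 1.],
                          [1., 1., 0.]])  # second image has only 2 real regions

module = nn.Linear(5, 4)                  # stand-in for self.att_embed

out = pack_wrapper(module, att_feats, att_masks)
print(out.shape)                          # torch.Size([2, 3, 4]); the padded slot stays zero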
52
+ class AttModel(CaptionModel):
53
+ def __init__(self, opt):
54
+ super(AttModel, self).__init__()
55
+ self.vocab_size = opt.vocab_size
56
+ self.input_encoding_size = opt.input_encoding_size
57
+ #self.rnn_type = opt.rnn_type
58
+ self.rnn_size = opt.rnn_size
59
+ self.num_layers = opt.num_layers
60
+ self.drop_prob_lm = opt.drop_prob_lm
61
+ self.seq_length = getattr(opt, 'max_length', 20) or opt.seq_length # maximum sample length
62
+ self.fc_feat_size = opt.fc_feat_size
63
+ self.att_feat_size = opt.att_feat_size
64
+ self.att_hid_size = opt.att_hid_size
65
+
66
+ self.bos_idx = getattr(opt, 'bos_idx', 0)
67
+ self.eos_idx = getattr(opt, 'eos_idx', 0)
68
+ self.pad_idx = getattr(opt, 'pad_idx', 0)
69
+
70
+ self.use_bn = getattr(opt, 'use_bn', 0)
71
+
72
+ self.ss_prob = 0.0 # Schedule sampling probability
73
+
74
+ self.embed = nn.Sequential(nn.Embedding(self.vocab_size + 1, self.input_encoding_size),
75
+ nn.ReLU(),
76
+ nn.Dropout(self.drop_prob_lm))
77
+ self.fc_embed = nn.Sequential(nn.Linear(self.fc_feat_size, self.rnn_size),
78
+ nn.ReLU(),
79
+ nn.Dropout(self.drop_prob_lm))
80
+ self.att_embed = nn.Sequential(*(
81
+ ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+
82
+ (nn.Linear(self.att_feat_size, self.rnn_size),
83
+ nn.ReLU(),
84
+ nn.Dropout(self.drop_prob_lm))+
85
+ ((nn.BatchNorm1d(self.rnn_size),) if self.use_bn==2 else ())))
86
+
87
+ self.logit_layers = getattr(opt, 'logit_layers', 1)
88
+ if self.logit_layers == 1:
89
+ self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
90
+ else:
91
+ self.logit = [[nn.Linear(self.rnn_size, self.rnn_size), nn.ReLU(), nn.Dropout(0.5)] for _ in range(opt.logit_layers - 1)]
92
+ self.logit = nn.Sequential(*(reduce(lambda x,y:x+y, self.logit) + [nn.Linear(self.rnn_size, self.vocab_size + 1)]))
93
+ self.ctx2att = nn.Linear(self.rnn_size, self.att_hid_size)
94
+
95
+ # For removing bad endings
96
+ self.vocab = opt.vocab
97
+ self.bad_endings_ix = [int(k) for k,v in self.vocab.items() if v in bad_endings]
98
+
99
+ def init_hidden(self, bsz):
100
+ weight = self.logit.weight \
101
+ if hasattr(self.logit, "weight") \
102
+ else self.logit[0].weight
103
+ return (weight.new_zeros(self.num_layers, bsz, self.rnn_size),
104
+ weight.new_zeros(self.num_layers, bsz, self.rnn_size))
105
+
106
+ def clip_att(self, att_feats, att_masks):
107
+ # Clip the length of att_masks and att_feats to the maximum length
108
+ if att_masks is not None:
109
+ max_len = att_masks.data.long().sum(1).max()
110
+ att_feats = att_feats[:, :max_len].contiguous()
111
+ att_masks = att_masks[:, :max_len].contiguous()
112
+ return att_feats, att_masks
113
+
114
+ def _prepare_feature(self, fc_feats, att_feats, att_masks):
115
+ att_feats, att_masks = self.clip_att(att_feats, att_masks)
116
+
117
+ # embed fc and att feats
118
+ fc_feats = self.fc_embed(fc_feats)
119
+ att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
120
+
121
+ # Project the attention feats first to reduce memory and computation consumption.
122
+ p_att_feats = self.ctx2att(att_feats)
123
+
124
+ return fc_feats, att_feats, p_att_feats, att_masks
125
+
126
+ def _forward(self, fc_feats, att_feats, seq, att_masks=None):
127
+ batch_size = fc_feats.size(0)
128
+ if seq.ndim == 3: # B * seq_per_img * seq_len
129
+ seq = seq.reshape(-1, seq.shape[2])
130
+ seq_per_img = seq.shape[0] // batch_size
131
+ state = self.init_hidden(batch_size*seq_per_img)
132
+
133
+ outputs = fc_feats.new_zeros(batch_size*seq_per_img, seq.size(1), self.vocab_size+1)
134
+
135
+ # Prepare the features
136
+ p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
137
+ # pp_att_feats is used for attention, we cache it in advance to reduce computation cost
138
+
139
+ if seq_per_img > 1:
140
+ p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(seq_per_img,
141
+ [p_fc_feats, p_att_feats, pp_att_feats, p_att_masks]
142
+ )
143
+
144
+ for i in range(seq.size(1)):
145
+ if self.training and i >= 1 and self.ss_prob > 0.0: # otherwise no need to sample
146
+ sample_prob = fc_feats.new(batch_size*seq_per_img).uniform_(0, 1)
147
+ sample_mask = sample_prob < self.ss_prob
148
+ if sample_mask.sum() == 0:
149
+ it = seq[:, i].clone()
150
+ else:
151
+ sample_ind = sample_mask.nonzero().view(-1)
152
+ it = seq[:, i].data.clone()
153
+ prob_prev = torch.exp(outputs[:, i-1].detach()) # fetch prev distribution: shape Nx(M+1)
154
+ it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
155
+ else:
156
+ it = seq[:, i].clone()
157
+ # break if all the sequences end
158
+ if i >= 1 and seq[:, i].sum() == 0:
159
+ break
160
+
161
+ output, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)
162
+ outputs[:, i] = output
163
+
164
+ return outputs
165
+
166
+ def get_logprobs_state(self, it, fc_feats, att_feats, p_att_feats, att_masks, state, output_logsoftmax=1):
167
+ # 'it' contains a word index
168
+ xt = self.embed(it)
169
+
170
+ output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks)
171
+ if output_logsoftmax:
172
+ logprobs = F.log_softmax(self.logit(output), dim=1)
173
+ else:
174
+ logprobs = self.logit(output)
175
+
176
+ return logprobs, state
177
+
178
+ def _old_sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
179
+ beam_size = opt.get('beam_size', 10)
180
+ group_size = opt.get('group_size', 1)
181
+ sample_n = opt.get('sample_n', 10)
182
+ # when sample_n == beam_size then each beam is a sample.
183
+ assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'
184
+ batch_size = fc_feats.size(0)
185
+
186
+ p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
187
+
188
+ assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
189
+ seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long)
190
+ seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1)
191
+ # lets process every image independently for now, for simplicity
192
+
193
+ self.done_beams = [[] for _ in range(batch_size)]
194
+ for k in range(batch_size):
195
+ state = self.init_hidden(beam_size)
196
+ tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks = utils.repeat_tensors(beam_size,
197
+ [p_fc_feats[k:k+1], p_att_feats[k:k+1], pp_att_feats[k:k+1], p_att_masks[k:k+1] if att_masks is not None else None]
198
+ )
199
+
200
+ for t in range(1):
201
+ if t == 0: # input <bos>
202
+ it = fc_feats.new_full([beam_size], self.bos_idx, dtype=torch.long)
203
+
204
+ logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state)
205
+
206
+ self.done_beams[k] = self.old_beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt)
207
+ if sample_n == beam_size:
208
+ for _n in range(sample_n):
209
+ seq[k*sample_n+_n, :] = self.done_beams[k][_n]['seq']
210
+ seqLogprobs[k*sample_n+_n, :] = self.done_beams[k][_n]['logps']
211
+ else:
212
+ seq[k, :] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
213
+ seqLogprobs[k, :] = self.done_beams[k][0]['logps']
214
+ # return the samples and their log likelihoods
215
+ return seq, seqLogprobs
216
+
217
+
218
+ def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
219
+ beam_size = opt.get('beam_size', 10)
220
+ group_size = opt.get('group_size', 1)
221
+ sample_n = opt.get('sample_n', 10)
222
+ # when sample_n == beam_size then each beam is a sample.
223
+ assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'
224
+ batch_size = fc_feats.size(0)
225
+
226
+ p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
227
+
228
+ assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
229
+ seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long)
230
+ seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1)
231
+ # lets process every image independently for now, for simplicity
232
+
233
+ self.done_beams = [[] for _ in range(batch_size)]
234
+
235
+ state = self.init_hidden(batch_size)
236
+
237
+ # first step, feed bos
238
+ it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long)
239
+ logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)
240
+
241
+ p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(beam_size,
242
+ [p_fc_feats, p_att_feats, pp_att_feats, p_att_masks]
243
+ )
244
+ self.done_beams = self.beam_search(state, logprobs, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, opt=opt)
245
+ for k in range(batch_size):
246
+ if sample_n == beam_size:
247
+ for _n in range(sample_n):
248
+ seq_len = self.done_beams[k][_n]['seq'].shape[0]
249
+ seq[k*sample_n+_n, :seq_len] = self.done_beams[k][_n]['seq']
250
+ seqLogprobs[k*sample_n+_n, :seq_len] = self.done_beams[k][_n]['logps']
251
+ else:
252
+ seq_len = self.done_beams[k][0]['seq'].shape[0]
253
+ seq[k, :seq_len] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
254
+ seqLogprobs[k, :seq_len] = self.done_beams[k][0]['logps']
255
+ # return the samples and their log likelihoods
256
+ return seq, seqLogprobs
257
+
258
+ def _sample(self, fc_feats, att_feats, att_masks=None, opt={}):
259
+
260
+ sample_method = opt.get('sample_method', 'greedy')
261
+ beam_size = opt.get('beam_size', 1)
262
+ temperature = opt.get('temperature', 1.0)
263
+ sample_n = int(opt.get('sample_n', 1))
264
+ group_size = opt.get('group_size', 1)
265
+ output_logsoftmax = opt.get('output_logsoftmax', 1)
266
+ decoding_constraint = opt.get('decoding_constraint', 0)
267
+ block_trigrams = opt.get('block_trigrams', 0)
268
+ remove_bad_endings = opt.get('remove_bad_endings', 0)
269
+ if beam_size > 1 and sample_method in ['greedy', 'beam_search']:
270
+ return self._sample_beam(fc_feats, att_feats, att_masks, opt)
271
+ if group_size > 1:
272
+ return self._diverse_sample(fc_feats, att_feats, att_masks, opt)
273
+
274
+ batch_size = fc_feats.size(0)
275
+ state = self.init_hidden(batch_size*sample_n)
276
+
277
+ p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
278
+
279
+ if sample_n > 1:
280
+ p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(sample_n,
281
+ [p_fc_feats, p_att_feats, pp_att_feats, p_att_masks]
282
+ )
283
+
284
+ trigrams = [] # will be a list of batch_size dictionaries
285
+
286
+ seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long)
287
+ seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1)
288
+ for t in range(self.seq_length + 1):
289
+ if t == 0: # input <bos>
290
+ it = fc_feats.new_full([batch_size*sample_n], self.bos_idx, dtype=torch.long)
291
+
292
+ logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state, output_logsoftmax=output_logsoftmax)
293
+
294
+ if decoding_constraint and t > 0:
295
+ tmp = logprobs.new_zeros(logprobs.size())
296
+ tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf'))
297
+ logprobs = logprobs + tmp
298
+
299
+ if remove_bad_endings and t > 0:
300
+ tmp = logprobs.new_zeros(logprobs.size())
301
+ prev_bad = np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix)
302
+ # Make it impossible to generate bad_endings
303
+ tmp[torch.from_numpy(prev_bad.astype('uint8')), 0] = float('-inf')
304
+ logprobs = logprobs + tmp
305
+
306
+ # Mess with trigrams
307
+ # Copy from https://github.com/lukemelas/image-paragraph-captioning
308
+ if block_trigrams and t >= 3:
309
+ # Store trigram generated at last step
310
+ prev_two_batch = seq[:,t-3:t-1]
311
+ for i in range(batch_size): # = seq.size(0)
312
+ prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
313
+ current = seq[i][t-1]
314
+ if t == 3: # initialize
315
+ trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int}
316
+ elif t > 3:
317
+ if prev_two in trigrams[i]: # add to list
318
+ trigrams[i][prev_two].append(current)
319
+ else: # create list
320
+ trigrams[i][prev_two] = [current]
321
+ # Block used trigrams at next step
322
+ prev_two_batch = seq[:,t-2:t]
323
+ mask = torch.zeros(logprobs.size(), requires_grad=False).to(logprobs.device) # batch_size x vocab_size
324
+ for i in range(batch_size):
325
+ prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
326
+ if prev_two in trigrams[i]:
327
+ for j in trigrams[i][prev_two]:
328
+ mask[i,j] += 1
329
+ # Apply mask to log probs
330
+ #logprobs = logprobs - (mask * 1e9)
331
+ alpha = 2.0 # = 4
332
+ logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best)
333
+
334
+ # sample the next word
335
+ if t == self.seq_length: # skip if we achieve maximum length
336
+ break
337
+ it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, temperature)
338
+
339
+ # stop when all finished
340
+ if t == 0:
341
+ unfinished = it != self.eos_idx
342
+ else:
343
+ it[~unfinished] = self.pad_idx # so that eos_idx is not overwritten with 0
344
+ logprobs = logprobs * unfinished.unsqueeze(1).to(logprobs)
345
+ unfinished = unfinished & (it != self.eos_idx)
346
+ seq[:,t] = it
347
+ seqLogprobs[:,t] = logprobs
348
+ # quit loop if all sequences have finished
349
+ if unfinished.sum() == 0:
350
+ break
351
+
352
+ return seq, seqLogprobs
353
+
354
+ def _diverse_sample(self, fc_feats, att_feats, att_masks=None, opt={}):
355
+
356
+ sample_method = opt.get('sample_method', 'greedy')
357
+ beam_size = opt.get('beam_size', 1)
358
+ temperature = opt.get('temperature', 1.0)
359
+ group_size = opt.get('group_size', 1)
360
+ diversity_lambda = opt.get('diversity_lambda', 0.5)
361
+ decoding_constraint = opt.get('decoding_constraint', 0)
362
+ block_trigrams = opt.get('block_trigrams', 0)
363
+ remove_bad_endings = opt.get('remove_bad_endings', 0)
364
+
365
+ batch_size = fc_feats.size(0)
366
+ state = self.init_hidden(batch_size)
367
+
368
+ p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
369
+
370
+ trigrams_table = [[] for _ in range(group_size)] # will be a list of batch_size dictionaries
371
+
372
+ seq_table = [fc_feats.new_full((batch_size, self.seq_length), self.pad_idx, dtype=torch.long) for _ in range(group_size)]
373
+ seqLogprobs_table = [fc_feats.new_zeros(batch_size, self.seq_length) for _ in range(group_size)]
374
+ state_table = [self.init_hidden(batch_size) for _ in range(group_size)]
375
+
376
+ for tt in range(self.seq_length + group_size):
377
+ for divm in range(group_size):
378
+ t = tt - divm
379
+ seq = seq_table[divm]
380
+ seqLogprobs = seqLogprobs_table[divm]
381
+ trigrams = trigrams_table[divm]
382
+ if t >= 0 and t <= self.seq_length-1:
383
+ if t == 0: # input <bos>
384
+ it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long)
385
+ else:
386
+ it = seq[:, t-1] # changed
387
+
388
+ logprobs, state_table[divm] = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state_table[divm]) # changed
389
+ logprobs = F.log_softmax(logprobs / temperature, dim=-1)
390
+
391
+ # Add diversity
392
+ if divm > 0:
393
+ unaug_logprobs = logprobs.clone()
394
+ for prev_choice in range(divm):
395
+ prev_decisions = seq_table[prev_choice][:, t]
396
+ logprobs[:, prev_decisions] = logprobs[:, prev_decisions] - diversity_lambda
397
+
398
+ if decoding_constraint and t > 0:
399
+ tmp = logprobs.new_zeros(logprobs.size())
400
+ tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf'))
401
+ logprobs = logprobs + tmp
402
+
403
+ if remove_bad_endings and t > 0:
404
+ tmp = logprobs.new_zeros(logprobs.size())
405
+ prev_bad = np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix)
406
+ # Make it impossible to generate bad_endings
407
+ tmp[torch.from_numpy(prev_bad.astype('uint8')), 0] = float('-inf')
408
+ logprobs = logprobs + tmp
409
+
410
+ # Mess with trigrams
411
+ if block_trigrams and t >= 3:
412
+ # Store trigram generated at last step
413
+ prev_two_batch = seq[:,t-3:t-1]
414
+ for i in range(batch_size): # = seq.size(0)
415
+ prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
416
+ current = seq[i][t-1]
417
+ if t == 3: # initialize
418
+ trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int}
419
+ elif t > 3:
420
+ if prev_two in trigrams[i]: # add to list
421
+ trigrams[i][prev_two].append(current)
422
+ else: # create list
423
+ trigrams[i][prev_two] = [current]
424
+ # Block used trigrams at next step
425
+ prev_two_batch = seq[:,t-2:t]
426
+ mask = torch.zeros(logprobs.size(), requires_grad=False).cuda() # batch_size x vocab_size
427
+ for i in range(batch_size):
428
+ prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
429
+ if prev_two in trigrams[i]:
430
+ for j in trigrams[i][prev_two]:
431
+ mask[i,j] += 1
432
+ # Apply mask to log probs
433
+ #logprobs = logprobs - (mask * 1e9)
434
+ alpha = 2.0 # = 4
435
+ logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best)
436
+
437
+ it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, 1)
438
+
439
+ # stop when all finished
440
+ if t == 0:
441
+ unfinished = it != self.eos_idx
442
+ else:
443
+ unfinished = (seq[:,t-1] != self.pad_idx) & (seq[:,t-1] != self.eos_idx)
444
+ it[~unfinished] = self.pad_idx
445
+ unfinished = unfinished & (it != self.eos_idx) # changed
446
+ seq[:,t] = it
447
+ seqLogprobs[:,t] = sampleLogprobs.view(-1)
448
+
449
+ return torch.stack(seq_table, 1).reshape(batch_size * group_size, -1), torch.stack(seqLogprobs_table, 1).reshape(batch_size * group_size, -1)
450
+
451
+ class AdaAtt_lstm(nn.Module):
452
+ def __init__(self, opt, use_maxout=True):
453
+ super(AdaAtt_lstm, self).__init__()
454
+ self.input_encoding_size = opt.input_encoding_size
455
+ #self.rnn_type = opt.rnn_type
456
+ self.rnn_size = opt.rnn_size
457
+ self.num_layers = opt.num_layers
458
+ self.drop_prob_lm = opt.drop_prob_lm
459
+ self.fc_feat_size = opt.fc_feat_size
460
+ self.att_feat_size = opt.att_feat_size
461
+ self.att_hid_size = opt.att_hid_size
462
+
463
+ self.use_maxout = use_maxout
464
+
465
+ # Build a LSTM
466
+ self.w2h = nn.Linear(self.input_encoding_size, (4+(use_maxout==True)) * self.rnn_size)
467
+ self.v2h = nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size)
468
+
469
+ self.i2h = nn.ModuleList([nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size) for _ in range(self.num_layers - 1)])
470
+ self.h2h = nn.ModuleList([nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size) for _ in range(self.num_layers)])
471
+
472
+ # Layers for getting the fake region
473
+ if self.num_layers == 1:
474
+ self.r_w2h = nn.Linear(self.input_encoding_size, self.rnn_size)
475
+ self.r_v2h = nn.Linear(self.rnn_size, self.rnn_size)
476
+ else:
477
+ self.r_i2h = nn.Linear(self.rnn_size, self.rnn_size)
478
+ self.r_h2h = nn.Linear(self.rnn_size, self.rnn_size)
479
+
480
+
481
+ def forward(self, xt, img_fc, state):
482
+
483
+ hs = []
484
+ cs = []
485
+ for L in range(self.num_layers):
486
+ # c,h from previous timesteps
487
+ prev_h = state[0][L]
488
+ prev_c = state[1][L]
489
+ # the input to this layer
490
+ if L == 0:
491
+ x = xt
492
+ i2h = self.w2h(x) + self.v2h(img_fc)
493
+ else:
494
+ x = hs[-1]
495
+ x = F.dropout(x, self.drop_prob_lm, self.training)
496
+ i2h = self.i2h[L-1](x)
497
+
498
+ all_input_sums = i2h+self.h2h[L](prev_h)
499
+
500
+ sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
501
+ sigmoid_chunk = torch.sigmoid(sigmoid_chunk)
502
+ # decode the gates
503
+ in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
504
+ forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
505
+ out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)
506
+ # decode the write inputs
507
+ if not self.use_maxout:
508
+ in_transform = torch.tanh(all_input_sums.narrow(1, 3 * self.rnn_size, self.rnn_size))
509
+ else:
510
+ in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size)
511
+ in_transform = torch.max(\
512
+ in_transform.narrow(1, 0, self.rnn_size),
513
+ in_transform.narrow(1, self.rnn_size, self.rnn_size))
514
+ # perform the LSTM update
515
+ next_c = forget_gate * prev_c + in_gate * in_transform
516
+ # gated cells form the output
517
+ tanh_nex_c = torch.tanh(next_c)
518
+ next_h = out_gate * tanh_nex_c
519
+ if L == self.num_layers-1:
520
+ if L == 0:
521
+ i2h = self.r_w2h(x) + self.r_v2h(img_fc)
522
+ else:
523
+ i2h = self.r_i2h(x)
524
+ n5 = i2h+self.r_h2h(prev_h)
525
+ fake_region = torch.sigmoid(n5) * tanh_nex_c
526
+
527
+ cs.append(next_c)
528
+ hs.append(next_h)
529
+
530
+ # set up the decoder
531
+ top_h = hs[-1]
532
+ top_h = F.dropout(top_h, self.drop_prob_lm, self.training)
533
+ fake_region = F.dropout(fake_region, self.drop_prob_lm, self.training)
534
+
535
+ state = (torch.cat([_.unsqueeze(0) for _ in hs], 0),
536
+ torch.cat([_.unsqueeze(0) for _ in cs], 0))
537
+ return top_h, fake_region, state
538
+
539
+ class AdaAtt_attention(nn.Module):
540
+ def __init__(self, opt):
541
+ super(AdaAtt_attention, self).__init__()
542
+ self.input_encoding_size = opt.input_encoding_size
543
+ #self.rnn_type = opt.rnn_type
544
+ self.rnn_size = opt.rnn_size
545
+ self.drop_prob_lm = opt.drop_prob_lm
546
+ self.att_hid_size = opt.att_hid_size
547
+
548
+ # fake region embed
549
+ self.fr_linear = nn.Sequential(
550
+ nn.Linear(self.rnn_size, self.input_encoding_size),
551
+ nn.ReLU(),
552
+ nn.Dropout(self.drop_prob_lm))
553
+ self.fr_embed = nn.Linear(self.input_encoding_size, self.att_hid_size)
554
+
555
+ # h out embed
556
+ self.ho_linear = nn.Sequential(
557
+ nn.Linear(self.rnn_size, self.input_encoding_size),
558
+ nn.Tanh(),
559
+ nn.Dropout(self.drop_prob_lm))
560
+ self.ho_embed = nn.Linear(self.input_encoding_size, self.att_hid_size)
561
+
562
+ self.alpha_net = nn.Linear(self.att_hid_size, 1)
563
+ self.att2h = nn.Linear(self.rnn_size, self.rnn_size)
564
+
565
+ def forward(self, h_out, fake_region, conv_feat, conv_feat_embed, att_masks=None):
566
+
567
+ # View into three dimensions
568
+ att_size = conv_feat.numel() // conv_feat.size(0) // self.rnn_size
569
+ conv_feat = conv_feat.view(-1, att_size, self.rnn_size)
570
+ conv_feat_embed = conv_feat_embed.view(-1, att_size, self.att_hid_size)
571
+
572
+ # view neighbor from batch_size * neighbor_num x rnn_size to batch_size x rnn_size * neighbor_num
573
+ fake_region = self.fr_linear(fake_region)
574
+ fake_region_embed = self.fr_embed(fake_region)
575
+
576
+ h_out_linear = self.ho_linear(h_out)
577
+ h_out_embed = self.ho_embed(h_out_linear)
578
+
579
+ txt_replicate = h_out_embed.unsqueeze(1).expand(h_out_embed.size(0), att_size + 1, h_out_embed.size(1))
580
+
581
+ img_all = torch.cat([fake_region.view(-1,1,self.input_encoding_size), conv_feat], 1)
582
+ img_all_embed = torch.cat([fake_region_embed.view(-1,1,self.input_encoding_size), conv_feat_embed], 1)
583
+
584
+ hA = torch.tanh(img_all_embed + txt_replicate)
585
+ hA = F.dropout(hA,self.drop_prob_lm, self.training)
586
+
587
+ hAflat = self.alpha_net(hA.view(-1, self.att_hid_size))
588
+ PI = F.softmax(hAflat.view(-1, att_size + 1), dim=1)
589
+
590
+ if att_masks is not None:
591
+ att_masks = att_masks.view(-1, att_size)
592
+ PI = PI * torch.cat([att_masks[:,:1], att_masks], 1) # assume the first mask entry is one.
593
+ PI = PI / PI.sum(1, keepdim=True)
594
+
595
+ visAtt = torch.bmm(PI.unsqueeze(1), img_all)
596
+ visAttdim = visAtt.squeeze(1)
597
+
598
+ atten_out = visAttdim + h_out_linear
599
+
600
+ h = torch.tanh(self.att2h(atten_out))
601
+ h = F.dropout(h, self.drop_prob_lm, self.training)
602
+ return h
603
+
604
+ class AdaAttCore(nn.Module):
605
+ def __init__(self, opt, use_maxout=False):
606
+ super(AdaAttCore, self).__init__()
607
+ self.lstm = AdaAtt_lstm(opt, use_maxout)
608
+ self.attention = AdaAtt_attention(opt)
609
+
610
+ def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
611
+ h_out, p_out, state = self.lstm(xt, fc_feats, state)
612
+ atten_out = self.attention(h_out, p_out, att_feats, p_att_feats, att_masks)
613
+ return atten_out, state
614
+
615
+ class UpDownCore(nn.Module):
616
+ def __init__(self, opt, use_maxout=False):
617
+ super(UpDownCore, self).__init__()
618
+ self.drop_prob_lm = opt.drop_prob_lm
619
+
620
+ self.att_lstm = nn.LSTMCell(opt.input_encoding_size + opt.rnn_size * 2, opt.rnn_size) # we, fc, h^2_t-1
621
+ self.lang_lstm = nn.LSTMCell(opt.rnn_size * 2, opt.rnn_size) # h^1_t, \hat v
622
+ self.attention = Attention(opt)
623
+
624
+ def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
625
+ prev_h = state[0][-1]
626
+ att_lstm_input = torch.cat([prev_h, fc_feats, xt], 1)
627
+
628
+ h_att, c_att = self.att_lstm(att_lstm_input, (state[0][0], state[1][0]))
629
+
630
+ att = self.attention(h_att, att_feats, p_att_feats, att_masks)
631
+
632
+ lang_lstm_input = torch.cat([att, h_att], 1)
633
+ # lang_lstm_input = torch.cat([att, F.dropout(h_att, self.drop_prob_lm, self.training)], 1) ?????
634
+
635
+ h_lang, c_lang = self.lang_lstm(lang_lstm_input, (state[0][1], state[1][1]))
636
+
637
+ output = F.dropout(h_lang, self.drop_prob_lm, self.training)
638
+ state = (torch.stack([h_att, h_lang]), torch.stack([c_att, c_lang]))
639
+
640
+ return output, state
641
+
642
+
643
+ ############################################################################
644
+ # Notice:
645
+ # StackAtt and DenseAtt are models that I randomly designed.
646
+ # They are not related to any paper.
647
+ ############################################################################
648
+
649
+ from .FCModel import LSTMCore
650
+ class StackAttCore(nn.Module):
651
+ def __init__(self, opt, use_maxout=False):
652
+ super(StackAttCore, self).__init__()
653
+ self.drop_prob_lm = opt.drop_prob_lm
654
+
655
+ # self.att0 = Attention(opt)
656
+ self.att1 = Attention(opt)
657
+ self.att2 = Attention(opt)
658
+
659
+ opt_input_encoding_size = opt.input_encoding_size
660
+ opt.input_encoding_size = opt.input_encoding_size + opt.rnn_size
661
+ self.lstm0 = LSTMCore(opt) # att_feat + word_embedding
662
+ opt.input_encoding_size = opt.rnn_size * 2
663
+ self.lstm1 = LSTMCore(opt)
664
+ self.lstm2 = LSTMCore(opt)
665
+ opt.input_encoding_size = opt_input_encoding_size
666
+
667
+ # self.emb1 = nn.Linear(opt.rnn_size, opt.rnn_size)
668
+ self.emb2 = nn.Linear(opt.rnn_size, opt.rnn_size)
669
+
670
+ def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
671
+ # att_res_0 = self.att0(state[0][-1], att_feats, p_att_feats, att_masks)
672
+ h_0, state_0 = self.lstm0(torch.cat([xt,fc_feats],1), [state[0][0:1], state[1][0:1]])
673
+ att_res_1 = self.att1(h_0, att_feats, p_att_feats, att_masks)
674
+ h_1, state_1 = self.lstm1(torch.cat([h_0,att_res_1],1), [state[0][1:2], state[1][1:2]])
675
+ att_res_2 = self.att2(h_1 + self.emb2(att_res_1), att_feats, p_att_feats, att_masks)
676
+ h_2, state_2 = self.lstm2(torch.cat([h_1,att_res_2],1), [state[0][2:3], state[1][2:3]])
677
+
678
+ return h_2, [torch.cat(_, 0) for _ in zip(state_0, state_1, state_2)]
679
+
680
+ class DenseAttCore(nn.Module):
681
+ def __init__(self, opt, use_maxout=False):
682
+ super(DenseAttCore, self).__init__()
683
+ self.drop_prob_lm = opt.drop_prob_lm
684
+
685
+ # self.att0 = Attention(opt)
686
+ self.att1 = Attention(opt)
687
+ self.att2 = Attention(opt)
688
+
689
+ opt_input_encoding_size = opt.input_encoding_size
690
+ opt.input_encoding_size = opt.input_encoding_size + opt.rnn_size
691
+ self.lstm0 = LSTMCore(opt) # att_feat + word_embedding
692
+ opt.input_encoding_size = opt.rnn_size * 2
693
+ self.lstm1 = LSTMCore(opt)
694
+ self.lstm2 = LSTMCore(opt)
695
+ opt.input_encoding_size = opt_input_encoding_size
696
+
697
+ # self.emb1 = nn.Linear(opt.rnn_size, opt.rnn_size)
698
+ self.emb2 = nn.Linear(opt.rnn_size, opt.rnn_size)
699
+
700
+ # fuse h_0 and h_1
701
+ self.fusion1 = nn.Sequential(nn.Linear(opt.rnn_size*2, opt.rnn_size),
702
+ nn.ReLU(),
703
+ nn.Dropout(opt.drop_prob_lm))
704
+ # fuse h_0, h_1 and h_2
705
+ self.fusion2 = nn.Sequential(nn.Linear(opt.rnn_size*3, opt.rnn_size),
706
+ nn.ReLU(),
707
+ nn.Dropout(opt.drop_prob_lm))
708
+
709
+ def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
710
+ # att_res_0 = self.att0(state[0][-1], att_feats, p_att_feats, att_masks)
711
+ h_0, state_0 = self.lstm0(torch.cat([xt,fc_feats],1), [state[0][0:1], state[1][0:1]])
712
+ att_res_1 = self.att1(h_0, att_feats, p_att_feats, att_masks)
713
+ h_1, state_1 = self.lstm1(torch.cat([h_0,att_res_1],1), [state[0][1:2], state[1][1:2]])
714
+ att_res_2 = self.att2(h_1 + self.emb2(att_res_1), att_feats, p_att_feats, att_masks)
715
+ h_2, state_2 = self.lstm2(torch.cat([self.fusion1(torch.cat([h_0, h_1], 1)),att_res_2],1), [state[0][2:3], state[1][2:3]])
716
+
717
+ return self.fusion2(torch.cat([h_0, h_1, h_2], 1)), [torch.cat(_, 0) for _ in zip(state_0, state_1, state_2)]
718
+
719
+ class Attention(nn.Module):
720
+ def __init__(self, opt):
721
+ super(Attention, self).__init__()
722
+ self.rnn_size = opt.rnn_size
723
+ self.att_hid_size = opt.att_hid_size
724
+
725
+ self.h2att = nn.Linear(self.rnn_size, self.att_hid_size)
726
+ self.alpha_net = nn.Linear(self.att_hid_size, 1)
727
+
728
+ def forward(self, h, att_feats, p_att_feats, att_masks=None):
729
+ # The p_att_feats here is already projected
730
+ att_size = att_feats.numel() // att_feats.size(0) // att_feats.size(-1)
731
+ att = p_att_feats.view(-1, att_size, self.att_hid_size)
732
+
733
+ att_h = self.h2att(h) # batch * att_hid_size
734
+ att_h = att_h.unsqueeze(1).expand_as(att) # batch * att_size * att_hid_size
735
+ dot = att + att_h # batch * att_size * att_hid_size
736
+ dot = torch.tanh(dot) # batch * att_size * att_hid_size
737
+ dot = dot.view(-1, self.att_hid_size) # (batch * att_size) * att_hid_size
738
+ dot = self.alpha_net(dot) # (batch * att_size) * 1
739
+ dot = dot.view(-1, att_size) # batch * att_size
740
+
741
+ weight = F.softmax(dot, dim=1) # batch * att_size
742
+ if att_masks is not None:
743
+ weight = weight * att_masks.view(-1, att_size).to(weight)
744
+ weight = weight / weight.sum(1, keepdim=True) # normalize to 1
745
+ att_feats_ = att_feats.view(-1, att_size, att_feats.size(-1)) # batch * att_size * att_feat_size
746
+ att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1) # batch * att_feat_size
747
+
748
+ return att_res
749
+
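To make the shape comments in the Attention module above concrete, here is a standalone sketch of the same additive-attention computation with assumed sizes (rnn_size=512, att_hid_size=512, 36 regions of 2048-d features); mask handling is omitted and the tensors are random, so this is an illustration rather than code from the commit.

# Illustrative shape check of the additive attention above; not part of the commit.
import torch
import torch.nn as nn
import torch.nn.functional as F

B, att_size, rnn_size, att_hid, feat_dim = 2, 36, 512, 512, 2048
h = torch.randn(B, rnn_size)                       # decoder hidden state
att_feats = torch.randn(B, att_size, feat_dim)     # raw region features
p_att_feats = torch.randn(B, att_size, att_hid)    # pre-projected region features
h2att = nn.Linear(rnn_size, att_hid)
alpha_net = nn.Linear(att_hid, 1)

dot = torch.tanh(p_att_feats + h2att(h).unsqueeze(1))           # B x att_size x att_hid
weight = F.softmax(alpha_net(dot).squeeze(-1), dim=1)           # B x att_size
att_res = torch.bmm(weight.unsqueeze(1), att_feats).squeeze(1)  # B x feat_dim
assert att_res.shape == (B, feat_dim)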
750
+ class Att2in2Core(nn.Module):
751
+ def __init__(self, opt):
752
+ super(Att2in2Core, self).__init__()
753
+ self.input_encoding_size = opt.input_encoding_size
754
+ #self.rnn_type = opt.rnn_type
755
+ self.rnn_size = opt.rnn_size
756
+ #self.num_layers = opt.num_layers
757
+ self.drop_prob_lm = opt.drop_prob_lm
758
+ self.fc_feat_size = opt.fc_feat_size
759
+ self.att_feat_size = opt.att_feat_size
760
+ self.att_hid_size = opt.att_hid_size
761
+
762
+ # Build a LSTM
763
+ self.a2c = nn.Linear(self.rnn_size, 2 * self.rnn_size)
764
+ self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size)
765
+ self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
766
+ self.dropout = nn.Dropout(self.drop_prob_lm)
767
+
768
+ self.attention = Attention(opt)
769
+
770
+ def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
771
+ att_res = self.attention(state[0][-1], att_feats, p_att_feats, att_masks)
772
+
773
+ all_input_sums = self.i2h(xt) + self.h2h(state[0][-1])
774
+ sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
775
+ sigmoid_chunk = torch.sigmoid(sigmoid_chunk)
776
+ in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
777
+ forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
778
+ out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)
779
+
780
+ in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size) + \
781
+ self.a2c(att_res)
782
+ in_transform = torch.max(\
783
+ in_transform.narrow(1, 0, self.rnn_size),
784
+ in_transform.narrow(1, self.rnn_size, self.rnn_size))
785
+ next_c = forget_gate * state[1][-1] + in_gate * in_transform
786
+ next_h = out_gate * torch.tanh(next_c)
787
+
788
+ output = self.dropout(next_h)
789
+ state = (next_h.unsqueeze(0), next_c.unsqueeze(0))
790
+ return output, state
791
+
792
+ class Att2inCore(Att2in2Core):
793
+ def __init__(self, opt):
794
+ super(Att2inCore, self).__init__(opt)
795
+ del self.a2c
796
+ self.a2c = nn.Linear(self.att_feat_size, 2 * self.rnn_size)
797
+
798
+ """
799
+ Note: this is my attempt to replicate the att2all model in the self-critical paper.
800
+ However, it is not actually a correct replication. Will fix it.
801
+ """
802
+ class Att2all2Core(nn.Module):
803
+ def __init__(self, opt):
804
+ super(Att2all2Core, self).__init__()
805
+ self.input_encoding_size = opt.input_encoding_size
806
+ #self.rnn_type = opt.rnn_type
807
+ self.rnn_size = opt.rnn_size
808
+ #self.num_layers = opt.num_layers
809
+ self.drop_prob_lm = opt.drop_prob_lm
810
+ self.fc_feat_size = opt.fc_feat_size
811
+ self.att_feat_size = opt.att_feat_size
812
+ self.att_hid_size = opt.att_hid_size
813
+
814
+ # Build a LSTM
815
+ self.a2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
816
+ self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size)
817
+ self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
818
+ self.dropout = nn.Dropout(self.drop_prob_lm)
819
+
820
+ self.attention = Attention(opt)
821
+
822
+ def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
823
+ att_res = self.attention(state[0][-1], att_feats, p_att_feats, att_masks)
824
+
825
+ all_input_sums = self.i2h(xt) + self.h2h(state[0][-1]) + self.a2h(att_res)
826
+ sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
827
+ sigmoid_chunk = torch.sigmoid(sigmoid_chunk)
828
+ in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
829
+ forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
830
+ out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)
831
+
832
+ in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size)
833
+ in_transform = torch.max(\
834
+ in_transform.narrow(1, 0, self.rnn_size),
835
+ in_transform.narrow(1, self.rnn_size, self.rnn_size))
836
+ next_c = forget_gate * state[1][-1] + in_gate * in_transform
837
+ next_h = out_gate * torch.tanh(next_c)
838
+
839
+ output = self.dropout(next_h)
840
+ state = (next_h.unsqueeze(0), next_c.unsqueeze(0))
841
+ return output, state
842
+
843
+ class AdaAttModel(AttModel):
844
+ def __init__(self, opt):
845
+ super(AdaAttModel, self).__init__(opt)
846
+ self.core = AdaAttCore(opt)
847
+
848
+ # AdaAtt with maxout lstm
849
+ class AdaAttMOModel(AttModel):
850
+ def __init__(self, opt):
851
+ super(AdaAttMOModel, self).__init__(opt)
852
+ self.core = AdaAttCore(opt, True)
853
+
854
+ class Att2in2Model(AttModel):
855
+ def __init__(self, opt):
856
+ super(Att2in2Model, self).__init__(opt)
857
+ self.core = Att2in2Core(opt)
858
+ delattr(self, 'fc_embed')
859
+ self.fc_embed = lambda x : x
860
+
861
+ class Att2all2Model(AttModel):
862
+ def __init__(self, opt):
863
+ super(Att2all2Model, self).__init__(opt)
864
+ self.core = Att2all2Core(opt)
865
+ delattr(self, 'fc_embed')
866
+ self.fc_embed = lambda x : x
867
+
868
+ class UpDownModel(AttModel):
869
+ def __init__(self, opt):
870
+ super(UpDownModel, self).__init__(opt)
871
+ self.num_layers = 2
872
+ self.core = UpDownCore(opt)
873
+
874
+ class StackAttModel(AttModel):
875
+ def __init__(self, opt):
876
+ super(StackAttModel, self).__init__(opt)
877
+ self.num_layers = 3
878
+ self.core = StackAttCore(opt)
879
+
880
+ class DenseAttModel(AttModel):
881
+ def __init__(self, opt):
882
+ super(DenseAttModel, self).__init__(opt)
883
+ self.num_layers = 3
884
+ self.core = DenseAttCore(opt)
885
+
886
+ class Att2inModel(AttModel):
887
+ def __init__(self, opt):
888
+ super(Att2inModel, self).__init__(opt)
889
+ del self.embed, self.fc_embed, self.att_embed
890
+ self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
891
+ self.fc_embed = self.att_embed = lambda x: x
892
+ del self.ctx2att
893
+ self.ctx2att = nn.Linear(self.att_feat_size, self.att_hid_size)
894
+ self.core = Att2inCore(opt)
895
+ self.init_weights()
896
+
897
+ def init_weights(self):
898
+ initrange = 0.1
899
+ self.embed.weight.data.uniform_(-initrange, initrange)
900
+ self.logit.bias.data.fill_(0)
901
+ self.logit.weight.data.uniform_(-initrange, initrange)
902
+
903
+
904
+ class NewFCModel(AttModel):
905
+ def __init__(self, opt):
906
+ super(NewFCModel, self).__init__(opt)
907
+ self.fc_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size)
908
+ self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
909
+ self._core = LSTMCore(opt)
910
+ delattr(self, 'att_embed')
911
+ self.att_embed = lambda x : x
912
+ delattr(self, 'ctx2att')
913
+ self.ctx2att = lambda x: x
914
+
915
+ def core(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks):
916
+ # Step 0, feed the input image
917
+ # if (self.training and state[0].is_leaf) or \
918
+ # (not self.training and state[0].sum() == 0):
919
+ # _, state = self._core(fc_feats, state)
920
+ # three cases
921
+ # normal mle training
922
+ # Sample
923
+ # beam search (diverse beam search)
924
+ # fixed captioning module.
925
+ is_first_step = (state[0]==0).all(2).all(0) # size: B
926
+ if is_first_step.all():
927
+ _, state = self._core(fc_feats, state)
928
+ elif is_first_step.any():
929
+ # This is mostly for diverse beam search I think
930
+ new_state = [torch.zeros_like(_) for _ in state]
931
+ new_state[0][:, ~is_first_step] = state[0][:, ~is_first_step]
932
+ new_state[1][:, ~is_first_step] = state[1][:, ~is_first_step]
933
+ _, state = self._core(fc_feats, state)
934
+ new_state[0][:, is_first_step] = state[0][:, is_first_step]
935
+ new_state[1][:, is_first_step] = state[1][:, is_first_step]
936
+ state = new_state
937
+ # if (state[0]==0).all():
938
+ # # Let's forget about diverse beam search first
939
+ # _, state = self._core(fc_feats, state)
940
+ return self._core(xt, state)
941
+
942
+ def _prepare_feature(self, fc_feats, att_feats, att_masks):
943
+ fc_feats = self.fc_embed(fc_feats)
944
+
945
+ return fc_feats, att_feats, att_feats, att_masks
946
+
947
+
948
+ class LMModel(AttModel):
949
+ def __init__(self, opt):
950
+ super(LMModel, self).__init__(opt)
951
+ delattr(self, 'fc_embed')
952
+ self.fc_embed = lambda x: x.new_zeros(x.shape[0], self.input_encoding_size)
953
+ self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
954
+ self._core = LSTMCore(opt)
955
+ delattr(self, 'att_embed')
956
+ self.att_embed = lambda x : x
957
+ delattr(self, 'ctx2att')
958
+ self.ctx2att = lambda x: x
959
+
960
+ def core(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks):
961
+ if (state[0]==0).all():
962
+ # Let's forget about diverse beam search first
963
+ _, state = self._core(fc_feats, state)
964
+ return self._core(xt, state)
965
+
966
+ def _prepare_feature(self, fc_feats, att_feats, att_masks):
967
+ fc_feats = self.fc_embed(fc_feats)
968
+
969
+ return fc_feats, None, None, None
captioning/models/BertCapModel.py ADDED
@@ -0,0 +1,104 @@
1
+ """
2
+ BertCapModel uses the huggingface transformers BERT model as a seq2seq model.
3
+
4
+ The result is not as good as the original transformer.
5
+ """
6
+
7
+ from __future__ import absolute_import
8
+ from __future__ import division
9
+ from __future__ import print_function
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ import copy
16
+ import math
17
+ import numpy as np
18
+
19
+ from .CaptionModel import CaptionModel
20
+ from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel
21
+ try:
22
+ from transformers import BertModel, BertConfig
23
+ except:
24
+ print('Huggingface transformers not installed; please visit https://github.com/huggingface/transformers')
25
+ from .TransformerModel import subsequent_mask, TransformerModel, Generator
26
+
27
+ class EncoderDecoder(nn.Module):
28
+ """
29
+ A standard Encoder-Decoder architecture. Base for this and many
30
+ other models.
31
+ """
32
+ def __init__(self, encoder, decoder, generator):
33
+ super(EncoderDecoder, self).__init__()
34
+ self.encoder = encoder
35
+ self.decoder = decoder
36
+ self.generator = generator
37
+
38
+ def forward(self, src, tgt, src_mask, tgt_mask):
39
+ "Take in and process masked src and target sequences."
40
+ return self.decode(self.encode(src, src_mask), src_mask,
41
+ tgt, tgt_mask)
42
+
43
+ def encode(self, src, src_mask):
44
+ return self.encoder(inputs_embeds=src,
45
+ attention_mask=src_mask)[0]
46
+
47
+ def decode(self, memory, src_mask, tgt, tgt_mask):
48
+ return self.decoder(input_ids=tgt,
49
+ attention_mask=tgt_mask,
50
+ encoder_hidden_states=memory,
51
+ encoder_attention_mask=src_mask)[0]
52
+
53
+
54
+ class BertCapModel(TransformerModel):
55
+
56
+ def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6,
57
+ d_model=512, d_ff=2048, h=8, dropout=0.1):
58
+ "Helper: Construct a model from hyperparameters."
59
+ enc_config = BertConfig(vocab_size=1,
60
+ hidden_size=d_model,
61
+ num_hidden_layers=N_enc,
62
+ num_attention_heads=h,
63
+ intermediate_size=d_ff,
64
+ hidden_dropout_prob=dropout,
65
+ attention_probs_dropout_prob=dropout,
66
+ max_position_embeddings=1,
67
+ type_vocab_size=1)
68
+ dec_config = BertConfig(vocab_size=tgt_vocab,
69
+ hidden_size=d_model,
70
+ num_hidden_layers=N_dec,
71
+ num_attention_heads=h,
72
+ intermediate_size=d_ff,
73
+ hidden_dropout_prob=dropout,
74
+ attention_probs_dropout_prob=dropout,
75
+ max_position_embeddings=17,
76
+ type_vocab_size=1,
77
+ is_decoder=True)
78
+ encoder = BertModel(enc_config)
79
+ def return_embeds(*args, **kwargs):
80
+ return kwargs['inputs_embeds']
81
+ del encoder.embeddings; encoder.embeddings = return_embeds
82
+ decoder = BertModel(dec_config)
83
+ model = EncoderDecoder(
84
+ encoder,
85
+ decoder,
86
+ Generator(d_model, tgt_vocab))
87
+ return model
88
+
89
+ def __init__(self, opt):
90
+ super(BertCapModel, self).__init__(opt)
91
+
92
+ def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
93
+ """
94
+ state = [ys.unsqueeze(0)]
95
+ """
96
+ if len(state) == 0:
97
+ ys = it.unsqueeze(1)
98
+ else:
99
+ ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
100
+ out = self.model.decode(memory, mask,
101
+ ys,
102
+ subsequent_mask(ys.size(1))
103
+ .to(memory.device))
104
+ return out[:, -1], [ys.unsqueeze(0)]
captioning/models/CaptionModel.py ADDED
@@ -0,0 +1,407 @@
1
+ # This file contains ShowAttendTell and AllImg model
2
+
3
+ # ShowAttendTell is from Show, Attend and Tell: Neural Image Caption Generation with Visual Attention
4
+ # https://arxiv.org/abs/1502.03044
5
+
6
+ # AllImg is a model where
7
+ # img feature is concatenated with word embedding at every time step as the input of lstm
8
+ from __future__ import absolute_import
9
+ from __future__ import division
10
+ from __future__ import print_function
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from torch.autograd import *
17
+ from ..utils import misc as utils
18
+ from . import utils as model_utils
19
+
20
+
21
+ class CaptionModel(nn.Module):
22
+ def __init__(self):
23
+ super(CaptionModel, self).__init__()
24
+
25
+ # implements beam search
26
+ # calls beam_step and returns the final set of beams
27
+ # augments log-probabilities with diversity terms when number of groups > 1
28
+
29
+ def forward(self, *args, **kwargs):
30
+ mode = kwargs.get('mode', 'forward')
31
+ if 'mode' in kwargs:
32
+ del kwargs['mode']
33
+ return getattr(self, '_'+mode)(*args, **kwargs)
34
+
35
+ def beam_search(self, init_state, init_logprobs, *args, **kwargs):
36
+
37
+ # function computes the similarity score to be augmented
38
+ def add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda, bdash):
39
+ local_time = t - divm
40
+ unaug_logprobs = logprobs.clone()
41
+ batch_size = beam_seq_table[0].shape[0]
42
+
43
+ if divm > 0:
44
+ change = logprobs.new_zeros(batch_size, logprobs.shape[-1])
45
+ for prev_choice in range(divm):
46
+ prev_decisions = beam_seq_table[prev_choice][:, :, local_time] # Nxb
47
+ for prev_labels in range(bdash):
48
+ change.scatter_add_(1, prev_decisions[:, prev_labels].unsqueeze(-1), change.new_ones(batch_size, 1))
49
+
50
+ if local_time == 0:
51
+ logprobs = logprobs - change * diversity_lambda
52
+ else:
53
+ logprobs = logprobs - self.repeat_tensor(bdash, change) * diversity_lambda
54
+
55
+ return logprobs, unaug_logprobs
56
+
57
+
58
+ # does one step of classical beam search
59
+
60
+ def beam_step(logprobs, unaug_logprobs, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
61
+ #INPUTS:
62
+ #logprobs: probabilities augmented after diversity N*bxV
63
+ #beam_size: obvious
64
+ #t : time instant
65
+ #beam_seq : tensor containing the beams
67
+ #beam_seq_logprobs: tensor containing the beam logprobs
68
+ #beam_logprobs_sum: tensor containing joint logprobs
69
+ #OUTPUTS:
69
+ #beam_seq : tensor containing the word indices of the decoded captions Nxbxl
70
+ #beam_seq_logprobs : log-probability of each decision made, NxbxlxV
71
+ #beam_logprobs_sum : joint log-probability of each beam Nxb
72
+
73
+ batch_size = beam_logprobs_sum.shape[0]
74
+ vocab_size = logprobs.shape[-1]
75
+ logprobs = logprobs.reshape(batch_size, -1, vocab_size) # NxbxV
76
+ if t == 0:
77
+ assert logprobs.shape[1] == 1
78
+ beam_logprobs_sum = beam_logprobs_sum[:, :1]
79
+ candidate_logprobs = beam_logprobs_sum.unsqueeze(-1) + logprobs # beam_logprobs_sum Nxb logprobs is NxbxV
80
+ ys, ix = torch.sort(candidate_logprobs.reshape(candidate_logprobs.shape[0], -1), -1, True)
81
+ ys, ix = ys[:,:beam_size], ix[:,:beam_size]
82
+ beam_ix = ix // vocab_size # Nxb which beam
83
+ selected_ix = ix % vocab_size # Nxb # which word
84
+ state_ix = (beam_ix + torch.arange(batch_size).type_as(beam_ix).unsqueeze(-1) * logprobs.shape[1]).reshape(-1) # N*b which in Nxb beams
85
+
86
+
87
+ if t > 0:
88
+ # gather according to beam_ix
89
+ assert (beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq)) == beam_seq.reshape(-1, beam_seq.shape[-1])[state_ix].view_as(beam_seq)).all()
90
+ beam_seq = beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq))
91
+
92
+ beam_seq_logprobs = beam_seq_logprobs.gather(1, beam_ix.unsqueeze(-1).unsqueeze(-1).expand_as(beam_seq_logprobs))
93
+
94
+ beam_seq = torch.cat([beam_seq, selected_ix.unsqueeze(-1)], -1) # beam_seq Nxbxl
95
+ beam_logprobs_sum = beam_logprobs_sum.gather(1, beam_ix) + \
96
+ logprobs.reshape(batch_size, -1).gather(1, ix)
97
+ assert (beam_logprobs_sum == ys).all()
98
+ _tmp_beam_logprobs = unaug_logprobs[state_ix].reshape(batch_size, -1, vocab_size)
99
+ beam_logprobs = unaug_logprobs.reshape(batch_size, -1, vocab_size).gather(1, beam_ix.unsqueeze(-1).expand(-1, -1, vocab_size)) # NxbxV
100
+ assert (_tmp_beam_logprobs == beam_logprobs).all()
101
+ beam_seq_logprobs = torch.cat([
102
+ beam_seq_logprobs,
103
+ beam_logprobs.reshape(batch_size, -1, 1, vocab_size)], 2)
104
+
105
+ new_state = [None for _ in state]
106
+ for _ix in range(len(new_state)):
107
+ # copy over state in previous beam q to new beam at vix
108
+ new_state[_ix] = state[_ix][:, state_ix]
109
+ state = new_state
110
+ return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state
111
+
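The selection logic inside beam_step can be summarized with a toy example (made-up sizes; an illustration, not code from the commit): flatten the beams-by-vocab candidate scores, take the top-k joint log-probabilities, then recover which beam and which word each winner came from.

# Toy illustration of the candidate selection performed in beam_step; not part of the commit.
import torch

N, b, V = 1, 3, 5                                   # batch, beams, vocab (assumed toy sizes)
beam_logprobs_sum = torch.zeros(N, b)               # running joint log-prob of each beam
logprobs = torch.log_softmax(torch.randn(N, b, V), dim=-1)

candidates = beam_logprobs_sum.unsqueeze(-1) + logprobs     # N x b x V
ys, ix = torch.sort(candidates.reshape(N, -1), -1, True)    # sort the flattened scores
ys, ix = ys[:, :b], ix[:, :b]                               # keep the best b expansions
beam_ix, word_ix = ix // V, ix % V                          # beam to extend, word to append
print(beam_ix, word_ix, ys)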
112
+ # Start diverse_beam_search
113
+ opt = kwargs['opt']
114
+ temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs
115
+ beam_size = opt.get('beam_size', 10)
116
+ group_size = opt.get('group_size', 1)
117
+ diversity_lambda = opt.get('diversity_lambda', 0.5)
118
+ decoding_constraint = opt.get('decoding_constraint', 0)
119
+ remove_bad_endings = opt.get('remove_bad_endings', 0)
120
+ suppress_UNK = opt.get('suppress_UNK', 0)
121
+ length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
122
+ bdash = beam_size // group_size # beam per group
123
+
124
+ batch_size = init_logprobs.shape[0]
125
+ device = init_logprobs.device
126
+ # INITIALIZATIONS
127
+ beam_seq_table = [torch.LongTensor(batch_size, bdash, 0).to(device) for _ in range(group_size)]
128
+ beam_seq_logprobs_table = [torch.FloatTensor(batch_size, bdash, 0, self.vocab_size + 1).to(device) for _ in range(group_size)]
129
+ beam_logprobs_sum_table = [torch.zeros(batch_size, bdash).to(device) for _ in range(group_size)]
130
+
131
+ # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1)
132
+ done_beams_table = [[[] for __ in range(group_size)] for _ in range(batch_size)]
133
+ # state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)]
134
+ # state_table = list(zip(*[_.reshape(-1, batch_size * bdash, group_size, *_.shape[2:]).chunk(group_size, 2) for _ in init_state]))
135
+ state_table = [[_.clone() for _ in init_state] for _ in range(group_size)]
136
+ # logprobs_table = list(init_logprobs.reshape(batch_size * bdash, group_size, -1).chunk(group_size, 0))
137
+ logprobs_table = [init_logprobs.clone() for _ in range(group_size)]
138
+ # END INIT
139
+
140
+ # Chunk elements in the args
141
+ args = list(args)
142
+ args = model_utils.split_tensors(group_size, args) # For each arg, turn (Bbg)x... to (Bb)x(g)x...
143
+ if self.__class__.__name__ == 'AttEnsemble':
144
+ args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)] # group_name, arg_name, model_name
145
+ else:
146
+ args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]
147
+
148
+ for t in range(self.seq_length + group_size - 1):
149
+ for divm in range(group_size):
150
+ if t >= divm and t <= self.seq_length + divm - 1:
151
+ # add diversity
152
+ logprobs = logprobs_table[divm]
153
+ # suppress previous word
154
+ if decoding_constraint and t-divm > 0:
155
+ logprobs.scatter_(1, beam_seq_table[divm][:, :, t-divm-1].reshape(-1, 1).to(device), float('-inf'))
156
+ if remove_bad_endings and t-divm > 0:
157
+ logprobs[torch.from_numpy(np.isin(beam_seq_table[divm][:, :, t-divm-1].cpu().numpy(), self.bad_endings_ix)).reshape(-1), 0] = float('-inf')
158
+ # suppress UNK tokens in the decoding
159
+ if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobs.size(1)-1)] == 'UNK':
160
+ logprobs[:,logprobs.size(1)-1] = logprobs[:, logprobs.size(1)-1] - 1000
161
+ # diversity is added here
162
+ # the function directly modifies the logprobs values and hence, we need to return
163
+ # the unaugmented ones for sorting the candidates in the end. # for historical
164
+ # reasons :-)
165
+ logprobs, unaug_logprobs = add_diversity(beam_seq_table,logprobs,t,divm,diversity_lambda,bdash)
166
+
167
+ # infer new beams
168
+ beam_seq_table[divm],\
169
+ beam_seq_logprobs_table[divm],\
170
+ beam_logprobs_sum_table[divm],\
171
+ state_table[divm] = beam_step(logprobs,
172
+ unaug_logprobs,
173
+ bdash,
174
+ t-divm,
175
+ beam_seq_table[divm],
176
+ beam_seq_logprobs_table[divm],
177
+ beam_logprobs_sum_table[divm],
178
+ state_table[divm])
179
+
180
+ # if time's up... or if end token is reached then copy beams
181
+ for b in range(batch_size):
182
+ is_end = beam_seq_table[divm][b, :, t-divm] == self.eos_idx
183
+ assert beam_seq_table[divm].shape[-1] == t-divm+1
184
+ if t == self.seq_length + divm - 1:
185
+ is_end.fill_(1)
186
+ for vix in range(bdash):
187
+ if is_end[vix]:
188
+ final_beam = {
189
+ 'seq': beam_seq_table[divm][b, vix].clone(),
190
+ 'logps': beam_seq_logprobs_table[divm][b, vix].clone(),
191
+ 'unaug_p': beam_seq_logprobs_table[divm][b, vix].sum().item(),
192
+ 'p': beam_logprobs_sum_table[divm][b, vix].item()
193
+ }
194
+ final_beam['p'] = length_penalty(t-divm+1, final_beam['p'])
195
+ done_beams_table[b][divm].append(final_beam)
196
+ beam_logprobs_sum_table[divm][b, is_end] -= 1000
197
+
198
+ # move the current group one step forward in time
199
+
200
+ it = beam_seq_table[divm][:, :, t-divm].reshape(-1).to(logprobs.device)
201
+ logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it, *(args[divm] + [state_table[divm]]))
202
+ logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)
203
+
204
+ # all beams are sorted by their log-probabilities
205
+ done_beams_table = [[sorted(done_beams_table[b][i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)] for b in range(batch_size)]
206
+ done_beams = [sum(_, []) for _ in done_beams_table]
207
+ return done_beams
208
+
209
+ def old_beam_search(self, init_state, init_logprobs, *args, **kwargs):
210
+
211
+ # function computes the similarity score to be augmented
212
+ def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash):
213
+ local_time = t - divm
214
+ unaug_logprobsf = logprobsf.clone()
215
+ for prev_choice in range(divm):
216
+ prev_decisions = beam_seq_table[prev_choice][local_time]
217
+ for sub_beam in range(bdash):
218
+ for prev_labels in range(bdash):
219
+ logprobsf[sub_beam][prev_decisions[prev_labels]] = logprobsf[sub_beam][prev_decisions[prev_labels]] - diversity_lambda
220
+ return unaug_logprobsf
221
+
222
+ # does one step of classical beam search
223
+
224
+ def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
225
+ #INPUTS:
226
+ #logprobsf: probabilities augmented after diversity
227
+ #beam_size: obvious
228
+ #t : time instant
229
+ #beam_seq : tensor containing the beams
230
+ #beam_seq_logprobs: tensor containing the beam logprobs
231
+ #beam_logprobs_sum: tensor containing joint logprobs
232
+ #OUTPUTS:
233
+ #beam_seq : tensor containing the word indices of the decoded captions
234
+ #beam_seq_logprobs : log-probability of each decision made, same size as beam_seq
235
+ #beam_logprobs_sum : joint log-probability of each beam
236
+
237
+ ys,ix = torch.sort(logprobsf,1,True)
238
+ candidates = []
239
+ cols = min(beam_size, ys.size(1))
240
+ rows = beam_size
241
+ if t == 0:
242
+ rows = 1
243
+ for c in range(cols): # for each column (word, essentially)
244
+ for q in range(rows): # for each beam expansion
245
+ #compute logprob of expanding beam q with word in (sorted) position c
246
+ local_logprob = ys[q,c].item()
247
+ candidate_logprob = beam_logprobs_sum[q] + local_logprob
248
+ # local_unaug_logprob = unaug_logprobsf[q,ix[q,c]]
249
+ candidates.append({'c':ix[q,c], 'q':q, 'p':candidate_logprob, 'r':unaug_logprobsf[q]})
250
+ candidates = sorted(candidates, key=lambda x: -x['p'])
251
+
252
+ new_state = [_.clone() for _ in state]
253
+ #beam_seq_prev, beam_seq_logprobs_prev
254
+ if t >= 1:
255
+ #we'll need these as reference when we fork beams around
256
+ beam_seq_prev = beam_seq[:t].clone()
257
+ beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone()
258
+ for vix in range(beam_size):
259
+ v = candidates[vix]
260
+ #fork beam index q into index vix
261
+ if t >= 1:
262
+ beam_seq[:t, vix] = beam_seq_prev[:, v['q']]
263
+ beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']]
264
+ #rearrange recurrent states
265
+ for state_ix in range(len(new_state)):
266
+ # copy over state in previous beam q to new beam at vix
267
+ new_state[state_ix][:, vix] = state[state_ix][:, v['q']] # dimension one is time step
268
+ #append new end terminal at the end of this beam
269
+ beam_seq[t, vix] = v['c'] # c'th word is the continuation
270
+ beam_seq_logprobs[t, vix] = v['r'] # the raw logprob here
271
+ beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam
272
+ state = new_state
273
+ return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state,candidates
274
+
275
+ # Start diverse_beam_search
276
+ opt = kwargs['opt']
277
+ temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs
278
+ beam_size = opt.get('beam_size', 10)
279
+ group_size = opt.get('group_size', 1)
280
+ diversity_lambda = opt.get('diversity_lambda', 0.5)
281
+ decoding_constraint = opt.get('decoding_constraint', 0)
282
+ remove_bad_endings = opt.get('remove_bad_endings', 0)
283
+ suppress_UNK = opt.get('suppress_UNK', 0)
284
+ length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
285
+ bdash = beam_size // group_size # beam per group
286
+
287
+ # INITIALIZATIONS
288
+ beam_seq_table = [torch.LongTensor(self.seq_length, bdash).zero_() for _ in range(group_size)]
289
+ beam_seq_logprobs_table = [torch.FloatTensor(self.seq_length, bdash, self.vocab_size + 1).zero_() for _ in range(group_size)]
290
+ beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)]
291
+
292
+ # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1)
293
+ done_beams_table = [[] for _ in range(group_size)]
294
+ # state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)]
295
+ state_table = list(zip(*[_.chunk(group_size, 1) for _ in init_state]))
296
+ logprobs_table = list(init_logprobs.chunk(group_size, 0))
297
+ # END INIT
298
+
299
+ # Chunk elements in the args
300
+ args = list(args)
301
+ if self.__class__.__name__ == 'AttEnsemble':
302
+ args = [[_.chunk(group_size) if _ is not None else [None]*group_size for _ in args_] for args_ in args] # arg_name, model_name, group_name
303
+ args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)] # group_name, arg_name, model_name
304
+ else:
305
+ args = [_.chunk(group_size) if _ is not None else [None]*group_size for _ in args]
306
+ args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]
307
+
308
+ for t in range(self.seq_length + group_size - 1):
309
+ for divm in range(group_size):
310
+ if t >= divm and t <= self.seq_length + divm - 1:
311
+ # add diversity
312
+ logprobsf = logprobs_table[divm]
313
+ # suppress previous word
314
+ if decoding_constraint and t-divm > 0:
315
+ logprobsf.scatter_(1, beam_seq_table[divm][t-divm-1].unsqueeze(1).to(logprobsf.device), float('-inf'))
316
+ if remove_bad_endings and t-divm > 0:
317
+ logprobsf[torch.from_numpy(np.isin(beam_seq_table[divm][t-divm-1].cpu().numpy(), self.bad_endings_ix)), 0] = float('-inf')
318
+ # suppress UNK tokens in the decoding
319
+ if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobsf.size(1)-1)] == 'UNK':
320
+ logprobsf[:,logprobsf.size(1)-1] = logprobsf[:, logprobsf.size(1)-1] - 1000
321
+ # diversity is added here
322
+ # the function directly modifies the logprobsf values and hence, we need to return
323
+ # the unaugmented ones for sorting the candidates in the end. # for historical
324
+ # reasons :-)
325
+ unaug_logprobsf = add_diversity(beam_seq_table,logprobsf,t,divm,diversity_lambda,bdash)
326
+
327
+ # infer new beams
328
+ beam_seq_table[divm],\
329
+ beam_seq_logprobs_table[divm],\
330
+ beam_logprobs_sum_table[divm],\
331
+ state_table[divm],\
332
+ candidates_divm = beam_step(logprobsf,
333
+ unaug_logprobsf,
334
+ bdash,
335
+ t-divm,
336
+ beam_seq_table[divm],
337
+ beam_seq_logprobs_table[divm],
338
+ beam_logprobs_sum_table[divm],
339
+ state_table[divm])
340
+
341
+ # if time's up... or if end token is reached then copy beams
342
+ for vix in range(bdash):
343
+ if beam_seq_table[divm][t-divm,vix] == self.eos_idx or t == self.seq_length + divm - 1:
344
+ final_beam = {
345
+ 'seq': beam_seq_table[divm][:, vix].clone(),
346
+ 'logps': beam_seq_logprobs_table[divm][:, vix].clone(),
347
+ 'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum().item(),
348
+ 'p': beam_logprobs_sum_table[divm][vix].item()
349
+ }
350
+ final_beam['p'] = length_penalty(t-divm+1, final_beam['p'])
351
+ done_beams_table[divm].append(final_beam)
352
+ # don't continue beams from finished sequences
353
+ beam_logprobs_sum_table[divm][vix] = -1000
354
+
355
+ # move the current group one step forward in time
356
+
357
+ it = beam_seq_table[divm][t-divm].to(logprobsf.device)
358
+ logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it, *(args[divm] + [state_table[divm]]))
359
+ logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)
360
+
361
+ # all beams are sorted by their log-probabilities
362
+ done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)]
363
+ done_beams = sum(done_beams_table, [])
364
+ return done_beams
365
+
366
+ def sample_next_word(self, logprobs, sample_method, temperature):
367
+ if sample_method == 'greedy':
368
+ sampleLogprobs, it = torch.max(logprobs.data, 1)
369
+ it = it.view(-1).long()
370
+ elif sample_method == 'gumbel': # gumbel softmax
371
+ # ref: https://gist.github.com/yzh119/fd2146d2aeb329d067568a493b20172f
372
+ def sample_gumbel(shape, eps=1e-20):
373
+ U = torch.rand(shape).to(logprobs.device)
374
+ return -torch.log(-torch.log(U + eps) + eps)
375
+ def gumbel_softmax_sample(logits, temperature):
376
+ y = logits + sample_gumbel(logits.size())
377
+ return F.log_softmax(y / temperature, dim=-1)
378
+ _logprobs = gumbel_softmax_sample(logprobs, temperature)
379
+ _, it = torch.max(_logprobs.data, 1)
380
+ sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions
381
+ else:
382
+ logprobs = logprobs / temperature
383
+ if sample_method.startswith('top'): # topk sampling
384
+ top_num = float(sample_method[3:])
385
+ if 0 < top_num < 1:
386
+ # nucleus sampling from # The Curious Case of Neural Text Degeneration
387
+ probs = F.softmax(logprobs, dim=1)
388
+ sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)
389
+ _cumsum = sorted_probs.cumsum(1)
390
+ mask = _cumsum < top_num
391
+ mask = torch.cat([torch.ones_like(mask[:,:1]), mask[:,:-1]], 1)
392
+ sorted_probs = sorted_probs * mask.to(sorted_probs)
393
+ sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
394
+ logprobs.scatter_(1, sorted_indices, sorted_probs.log())
395
+ else:
396
+ the_k = int(top_num)
397
+ tmp = torch.empty_like(logprobs).fill_(float('-inf'))
398
+ topk, indices = torch.topk(logprobs, the_k, dim=1)
399
+ tmp = tmp.scatter(1, indices, topk)
400
+ logprobs = tmp
401
+ it = torch.distributions.Categorical(logits=logprobs.detach()).sample()
402
+ sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions
403
+ return it, sampleLogprobs
404
+
405
+
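The `0 < top_num < 1` branch of sample_next_word above implements nucleus (top-p) sampling. A standalone sketch of the same filtering with made-up logits (not part of the commit): keep the smallest prefix of sorted tokens whose cumulative probability reaches p, renormalize, then sample.

# Illustrative nucleus (top-p) sampling sketch with made-up logits; not part of the commit.
import torch

logprobs = torch.log_softmax(torch.randn(2, 10), dim=-1)
p = 0.9
probs = logprobs.exp()
sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=1)
cum = sorted_probs.cumsum(1)
mask = cum < p
mask = torch.cat([torch.ones_like(mask[:, :1]), mask[:, :-1]], 1)   # always keep the top token
kept = sorted_probs * mask
kept = kept / kept.sum(1, keepdim=True)                             # renormalize the nucleus
it = torch.distributions.Categorical(probs=kept).sample()           # sample a sorted position
it = sorted_idx.gather(1, it.unsqueeze(1)).squeeze(1)               # map back to vocabulary ids
print(it)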
406
+ def decode_sequence(self, seq):
407
+ return utils.decode_sequence(self.vocab, seq)
captioning/models/FCModel.py ADDED
@@ -0,0 +1,204 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch.autograd import *
9
+ from . import utils
10
+
11
+ from .CaptionModel import CaptionModel
12
+
13
+ class LSTMCore(nn.Module):
14
+ def __init__(self, opt):
15
+ super(LSTMCore, self).__init__()
16
+ self.input_encoding_size = opt.input_encoding_size
17
+ self.rnn_size = opt.rnn_size
18
+ self.drop_prob_lm = opt.drop_prob_lm
19
+
20
+ # Build a LSTM
21
+ self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size)
22
+ self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
23
+ self.dropout = nn.Dropout(self.drop_prob_lm)
24
+
25
+ def forward(self, xt, state):
26
+
27
+ all_input_sums = self.i2h(xt) + self.h2h(state[0][-1])
28
+ sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
29
+ sigmoid_chunk = torch.sigmoid(sigmoid_chunk)
30
+ in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
31
+ forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
32
+ out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)
33
+
34
+ in_transform = torch.max(\
35
+ all_input_sums.narrow(1, 3 * self.rnn_size, self.rnn_size),
36
+ all_input_sums.narrow(1, 4 * self.rnn_size, self.rnn_size))
37
+ next_c = forget_gate * state[1][-1] + in_gate * in_transform
38
+ next_h = out_gate * torch.tanh(next_c)
39
+
40
+ output = self.dropout(next_h)
41
+ state = (next_h.unsqueeze(0), next_c.unsqueeze(0))
42
+ return output, state
43
+
44
+ class FCModel(CaptionModel):
45
+ def __init__(self, opt):
46
+ super(FCModel, self).__init__()
47
+ self.vocab_size = opt.vocab_size
48
+ self.input_encoding_size = opt.input_encoding_size
49
+ self.rnn_type = opt.rnn_type
50
+ self.rnn_size = opt.rnn_size
51
+ self.num_layers = opt.num_layers
52
+ self.drop_prob_lm = opt.drop_prob_lm
53
+ self.seq_length = opt.seq_length
54
+ self.fc_feat_size = opt.fc_feat_size
55
+
56
+ self.ss_prob = 0.0 # Schedule sampling probability
57
+
58
+ self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size)
59
+ self.core = LSTMCore(opt)
60
+ self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
61
+ self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
62
+
63
+ self.init_weights()
64
+
65
+ def init_weights(self):
66
+ initrange = 0.1
67
+ self.embed.weight.data.uniform_(-initrange, initrange)
68
+ self.logit.bias.data.fill_(0)
69
+ self.logit.weight.data.uniform_(-initrange, initrange)
70
+
71
+ def init_hidden(self, bsz):
72
+ weight = self.logit.weight
73
+ if self.rnn_type == 'lstm':
74
+ return (weight.new_zeros(self.num_layers, bsz, self.rnn_size),
75
+ weight.new_zeros(self.num_layers, bsz, self.rnn_size))
76
+ else:
77
+ return weight.new_zeros(self.num_layers, bsz, self.rnn_size)
78
+
79
+ def _forward(self, fc_feats, att_feats, seq, att_masks=None):
80
+ batch_size = fc_feats.size(0)
81
+ seq_per_img = seq.shape[0] // batch_size
82
+ state = self.init_hidden(batch_size*seq_per_img)
83
+ outputs = []
84
+
85
+ if seq_per_img > 1:
86
+ fc_feats = utils.repeat_tensors(seq_per_img, fc_feats)
87
+
88
+ for i in range(seq.size(1) + 1):
89
+ if i == 0:
90
+ xt = self.img_embed(fc_feats)
91
+ else:
92
+ if self.training and i >= 2 and self.ss_prob > 0.0: # otherwise no need to sample
93
+ sample_prob = fc_feats.data.new(batch_size*seq_per_img).uniform_(0, 1)
94
+ sample_mask = sample_prob < self.ss_prob
95
+ if sample_mask.sum() == 0:
96
+ it = seq[:, i-1].clone()
97
+ else:
98
+ sample_ind = sample_mask.nonzero().view(-1)
99
+ it = seq[:, i-1].data.clone()
100
+ #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
101
+ #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
102
+ prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
103
+ it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
104
+ else:
105
+ it = seq[:, i-1].clone()
106
+ # break if all the sequences end
107
+ if i >= 2 and seq[:, i-1].sum() == 0:
108
+ break
109
+ xt = self.embed(it)
110
+
111
+ output, state = self.core(xt, state)
112
+ output = F.log_softmax(self.logit(output), dim=1)
113
+ outputs.append(output)
114
+
115
+ return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous()
116
+
117
+ def get_logprobs_state(self, it, state):
118
+ # 'it' contains a word index
119
+ xt = self.embed(it)
120
+
121
+ output, state = self.core(xt, state)
122
+ logprobs = F.log_softmax(self.logit(output), dim=1)
123
+
124
+ return logprobs, state
125
+
126
+ def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
127
+ beam_size = opt.get('beam_size', 10)
128
+ batch_size = fc_feats.size(0)
129
+
130
+ assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
131
+ seq = torch.LongTensor(self.seq_length, batch_size).zero_()
132
+ seqLogprobs = torch.FloatTensor(self.seq_length, batch_size, self.vocab_size + 1)
133
+ # lets process every image independently for now, for simplicity
134
+
135
+ self.done_beams = [[] for _ in range(batch_size)]
136
+ for k in range(batch_size):
137
+ state = self.init_hidden(beam_size)
138
+ for t in range(2):
139
+ if t == 0:
140
+ xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size)
141
+ elif t == 1: # input <bos>
142
+ it = fc_feats.data.new(beam_size).long().zero_()
143
+ xt = self.embed(it)
144
+
145
+ output, state = self.core(xt, state)
146
+ logprobs = F.log_softmax(self.logit(output), dim=1)
147
+
148
+ self.done_beams[k] = self.beam_search(state, logprobs, opt=opt)
149
+ seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
150
+ seqLogprobs[:, k] = self.done_beams[k][0]['logps']
151
+ # return the samples and their log likelihoods
152
+ return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
153
+
154
+ def _sample(self, fc_feats, att_feats, att_masks=None, opt={}):
155
+ sample_method = opt.get('sample_method', 'greedy')
156
+ beam_size = opt.get('beam_size', 1)
157
+ temperature = opt.get('temperature', 1.0)
158
+ if beam_size > 1 and sample_method in ['greedy', 'beam_search']:
159
+ return self._sample_beam(fc_feats, att_feats, opt)
160
+
161
+ batch_size = fc_feats.size(0)
162
+ state = self.init_hidden(batch_size)
163
+ seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long)
164
+ seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length, self.vocab_size + 1)
165
+ for t in range(self.seq_length + 2):
166
+ if t == 0:
167
+ xt = self.img_embed(fc_feats)
168
+ else:
169
+ if t == 1: # input <bos>
170
+ it = fc_feats.data.new(batch_size).long().zero_()
171
+ xt = self.embed(it)
172
+
173
+ output, state = self.core(xt, state)
174
+ logprobs = F.log_softmax(self.logit(output), dim=1)
175
+
176
+ # sample the next_word
177
+ if t == self.seq_length + 1: # skip if we achieve maximum length
178
+ break
179
+ if sample_method == 'greedy':
180
+ sampleLogprobs, it = torch.max(logprobs.data, 1)
181
+ it = it.view(-1).long()
182
+ else:
183
+ if temperature == 1.0:
184
+ prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1)
185
+ else:
186
+ # scale logprobs by temperature
187
+ prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu()
188
+ it = torch.multinomial(prob_prev, 1).to(logprobs.device)
189
+ sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions
190
+ it = it.view(-1).long() # and flatten indices for downstream processing
191
+
192
+ if t >= 1:
193
+ # stop when all finished
194
+ if t == 1:
195
+ unfinished = it > 0
196
+ else:
197
+ unfinished = unfinished & (it > 0)
198
+ it = it * unfinished.type_as(it)
199
+ seq[:,t-1] = it #seq[t] the input of t+2 time step
200
+ seqLogprobs[:,t-1] = sampleLogprobs.view(-1)
201
+ if unfinished.sum() == 0:
202
+ break
203
+
204
+ return seq, seqLogprobs
captioning/models/M2Transformer.py ADDED
@@ -0,0 +1,98 @@
1
+ """
2
+ Instruction to use meshed_memory_transformer (https://arxiv.org/abs/1912.08226)
3
+
4
+ pip install git+https://github.com/ruotianluo/meshed-memory-transformer.git
5
+
6
+ Note:
7
+ Currently m2transformer is not performing as well as the original transformer. Not sure why; still investigating.
8
+ """
9
+
10
+ from __future__ import absolute_import
11
+ from __future__ import division
12
+ from __future__ import print_function
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+ import copy
19
+ import math
20
+ import numpy as np
21
+
22
+ from .CaptionModel import CaptionModel
23
+ from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel
24
+
25
+ try:
26
+ from m2transformer.models.transformer import Transformer, MemoryAugmentedEncoder, MeshedDecoder, ScaledDotProductAttentionMemory
27
+ except:
28
+ print('meshed-memory-transformer not installed; please run `pip install git+https://github.com/ruotianluo/meshed-memory-transformer.git`')
29
+ from .TransformerModel import subsequent_mask, TransformerModel
30
+
31
+
32
+ class M2TransformerModel(TransformerModel):
33
+
34
+ def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6,
35
+ d_model=512, d_ff=2048, h=8, dropout=0.1):
36
+ "Helper: Construct a model from hyperparameters."
37
+ encoder = MemoryAugmentedEncoder(N_enc, 0, attention_module=ScaledDotProductAttentionMemory,
38
+ attention_module_kwargs={'m': 40})
39
+ # Another implementation is to use MultiLevelEncoder + att_embed
40
+ decoder = MeshedDecoder(tgt_vocab, 54, N_dec, -1) # -1 is padding;
41
+ model = Transformer(0, encoder, decoder) # 0 is bos
42
+ return model
43
+
44
+ def __init__(self, opt):
45
+ super(M2TransformerModel, self).__init__(opt)
46
+ delattr(self, 'att_embed')
47
+ self.att_embed = lambda x: x # The visual embed is in the MAEncoder
48
+ # Notes: The dropout in MAEncoder is different from my att_embed, mine is 0.5?
49
+ # Also the attention mask seems wrong in MAEncoder too... interesting
50
+
51
+ def logit(self, x): # unsafe way
52
+ return x # M2Transformer always outputs log-softmax
53
+
54
+ def _prepare_feature(self, fc_feats, att_feats, att_masks):
55
+
56
+ att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks)
57
+ memory, att_masks = self.model.encoder(att_feats)
58
+
59
+ return fc_feats[...,:0], att_feats[...,:0], memory, att_masks
60
+
61
+ def _forward(self, fc_feats, att_feats, seq, att_masks=None):
62
+ if seq.ndim == 3: # B * seq_per_img * seq_len
63
+ seq = seq.reshape(-1, seq.shape[2])
64
+ att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq)
65
+
66
+ seq = seq.clone()
67
+ seq[~seq_mask.any(-2)] = -1 # Make padding to be -1 (my dataloader uses 0 as padding)
68
+ outputs = self.model(att_feats, seq)
69
+
70
+ return outputs
71
+
72
+ def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
73
+ """
74
+ state = [ys.unsqueeze(0)]
75
+ """
76
+ if len(state) == 0:
77
+ ys = it.unsqueeze(1)
78
+ else:
79
+ ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
80
+ out = self.model.decoder(ys, memory, mask)
81
+ return out[:, -1], [ys.unsqueeze(0)]
82
+
83
+ def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
84
+ beam_size = opt.get('beam_size', 10)
85
+ group_size = opt.get('group_size', 1)
86
+ sample_n = opt.get('sample_n', 10)
87
+ assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'
88
+
89
+ att_feats, _, __, ___ = self._prepare_feature_forward(att_feats, att_masks)
90
+ seq, logprobs, seqLogprobs = self.model.beam_search(att_feats, self.seq_length, 0,
91
+ beam_size, return_probs=True, out_size=beam_size)
92
+ seq = seq.reshape(-1, *seq.shape[2:])
93
+ seqLogprobs = seqLogprobs.reshape(-1, *seqLogprobs.shape[2:])
94
+
95
+ # if not (seqLogprobs.gather(-1, seq.unsqueeze(-1)).squeeze(-1) == logprobs.reshape(-1, logprobs.shape[-1])).all():
96
+ # import pudb;pu.db
97
+ # seqLogprobs = logprobs.reshape(-1, logprobs.shape[-1]).unsqueeze(-1).expand(-1,-1,seqLogprobs.shape[-1])
98
+ return seq, seqLogprobs
captioning/models/ShowTellModel.py ADDED
@@ -0,0 +1,174 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch.autograd import *
9
+ from . import utils
10
+
11
+ from .CaptionModel import CaptionModel
12
+
13
+ class ShowTellModel(CaptionModel):
14
+ def __init__(self, opt):
15
+ super(ShowTellModel, self).__init__()
16
+ self.vocab_size = opt.vocab_size
17
+ self.input_encoding_size = opt.input_encoding_size
18
+ self.rnn_type = opt.rnn_type
19
+ self.rnn_size = opt.rnn_size
20
+ self.num_layers = opt.num_layers
21
+ self.drop_prob_lm = opt.drop_prob_lm
22
+ self.seq_length = opt.seq_length
23
+ self.fc_feat_size = opt.fc_feat_size
24
+
25
+ self.ss_prob = 0.0 # Schedule sampling probability
26
+
27
+ self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size)
28
+ self.core = getattr(nn, self.rnn_type.upper())(self.input_encoding_size, self.rnn_size, self.num_layers, bias=False, dropout=self.drop_prob_lm)
29
+ self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
30
+ self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
31
+ self.dropout = nn.Dropout(self.drop_prob_lm)
32
+
33
+ self.init_weights()
34
+
35
+ def init_weights(self):
36
+ initrange = 0.1
37
+ self.embed.weight.data.uniform_(-initrange, initrange)
38
+ self.logit.bias.data.fill_(0)
39
+ self.logit.weight.data.uniform_(-initrange, initrange)
40
+
41
+ def init_hidden(self, bsz):
42
+ weight = self.logit.weight
43
+ if self.rnn_type == 'lstm':
44
+ return (weight.new_zeros(self.num_layers, bsz, self.rnn_size),
45
+ weight.new_zeros(self.num_layers, bsz, self.rnn_size))
46
+ else:
47
+ return weight.new_zeros(self.num_layers, bsz, self.rnn_size)
48
+
49
+ def _forward(self, fc_feats, att_feats, seq, att_masks=None):
50
+ batch_size = fc_feats.size(0)
51
+ seq_per_img = seq.shape[0] // batch_size
52
+ state = self.init_hidden(batch_size*seq_per_img)
53
+ outputs = []
54
+
55
+ if seq_per_img > 1:
56
+ fc_feats = utils.repeat_tensors(seq_per_img, fc_feats)
57
+
58
+ for i in range(seq.size(1) + 1):
59
+ if i == 0:
60
+ xt = self.img_embed(fc_feats)
61
+ else:
62
+ if self.training and i >= 2 and self.ss_prob > 0.0: # otherwise no need to sample
63
+ sample_prob = fc_feats.data.new(batch_size*seq_per_img).uniform_(0, 1)
64
+ sample_mask = sample_prob < self.ss_prob
65
+ if sample_mask.sum() == 0:
66
+ it = seq[:, i-1].clone()
67
+ else:
68
+ sample_ind = sample_mask.nonzero().view(-1)
69
+ it = seq[:, i-1].data.clone()
70
+ #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
71
+ #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
72
+ prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
73
+ it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
74
+ else:
75
+ it = seq[:, i-1].clone()
76
+ # break if all the sequences end
77
+ if i >= 2 and seq[:, i-1].data.sum() == 0:
78
+ break
79
+ xt = self.embed(it)
80
+
81
+ output, state = self.core(xt.unsqueeze(0), state)
82
+ output = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
83
+ outputs.append(output)
84
+
85
+ return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous()
86
+
87
+ def get_logprobs_state(self, it, state):
88
+ # 'it' contains a word index
89
+ xt = self.embed(it)
90
+
91
+ output, state = self.core(xt.unsqueeze(0), state)
92
+ logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
93
+
94
+ return logprobs, state
95
+
96
+ def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
97
+ beam_size = opt.get('beam_size', 10)
98
+ batch_size = fc_feats.size(0)
99
+
100
+ assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
101
+ seq = torch.LongTensor(self.seq_length, batch_size).zero_()
102
+ seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)
103
+ # lets process every image independently for now, for simplicity
104
+
105
+ self.done_beams = [[] for _ in range(batch_size)]
106
+ for k in range(batch_size):
107
+ state = self.init_hidden(beam_size)
108
+ for t in range(2):
109
+ if t == 0:
110
+ xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size)
111
+ elif t == 1: # input <bos>
112
+ it = fc_feats.data.new(beam_size).long().zero_()
113
+ xt = self.embed(it)
114
+
115
+ output, state = self.core(xt.unsqueeze(0), state)
116
+ logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
117
+
118
+ self.done_beams[k] = self.beam_search(state, logprobs, opt=opt)
119
+ seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
120
+ seqLogprobs[:, k] = self.done_beams[k][0]['logps']
121
+ # return the samples and their log likelihoods
122
+ return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
123
+
124
+ def _sample(self, fc_feats, att_feats, att_masks=None, opt={}):
125
+ sample_method = opt.get('sample_method', 'greedy')
126
+ beam_size = opt.get('beam_size', 1)
127
+ temperature = opt.get('temperature', 1.0)
128
+ if beam_size > 1 and sample_method in ['greedy', 'beam_search']:
129
+ return self._sample_beam(fc_feats, att_feats, opt)
130
+
131
+ batch_size = fc_feats.size(0)
132
+ state = self.init_hidden(batch_size)
133
+ seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long)
134
+ seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length)
135
+ for t in range(self.seq_length + 2):
136
+ if t == 0:
137
+ xt = self.img_embed(fc_feats)
138
+ else:
139
+ if t == 1: # input <bos>
140
+ it = fc_feats.data.new(batch_size).long().zero_()
141
+ xt = self.embed(it)
142
+
143
+ output, state = self.core(xt.unsqueeze(0), state)
144
+ logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
145
+
146
+ # sample the next word
147
+ if t == self.seq_length + 1: # skip if we achieve maximum length
148
+ break
149
+ if sample_method == 'greedy':
150
+ sampleLogprobs, it = torch.max(logprobs.data, 1)
151
+ it = it.view(-1).long()
152
+ else:
153
+ if temperature == 1.0:
154
+ prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1)
155
+ else:
156
+ # scale logprobs by temperature
157
+ prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu()
158
+ it = torch.multinomial(prob_prev, 1).to(logprobs.device)
159
+ sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions
160
+ it = it.view(-1).long() # and flatten indices for downstream processing
161
+
162
+ if t >= 1:
163
+ # stop when all finished
164
+ if t == 1:
165
+ unfinished = it > 0
166
+ else:
167
+ unfinished = unfinished & (it > 0)
168
+ it = it * unfinished.type_as(it)
169
+ seq[:,t-1] = it #seq[t] the input of t+2 time step
170
+ seqLogprobs[:,t-1] = sampleLogprobs.view(-1)
171
+ if unfinished.sum() == 0:
172
+ break
173
+
174
+ return seq, seqLogprobs
captioning/models/TransformerModel.py ADDED
@@ -0,0 +1,363 @@
1
+ # This file contains Transformer network
2
+ # Most of the code is copied from http://nlp.seas.harvard.edu/2018/04/03/attention.html
3
+
4
+ # The cfg name correspondance:
5
+ # N=num_layers
6
+ # d_model=input_encoding_size
7
+ # d_ff=rnn_size
8
+ # h is always 8
9
+
10
+ from __future__ import absolute_import
11
+ from __future__ import division
12
+ from __future__ import print_function
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from . import utils
18
+
19
+ import copy
20
+ import math
21
+ import numpy as np
22
+
23
+ from .CaptionModel import CaptionModel
24
+ from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel
25
+
26
+ class EncoderDecoder(nn.Module):
27
+ """
28
+ A standard Encoder-Decoder architecture. Base for this and many
29
+ other models.
30
+ """
31
+ def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
32
+ super(EncoderDecoder, self).__init__()
33
+ self.encoder = encoder
34
+ self.decoder = decoder
35
+ self.src_embed = src_embed
36
+ self.tgt_embed = tgt_embed
37
+ self.generator = generator
38
+
39
+ def forward(self, src, tgt, src_mask, tgt_mask):
40
+ "Take in and process masked src and target sequences."
41
+ return self.decode(self.encode(src, src_mask), src_mask,
42
+ tgt, tgt_mask)
43
+
44
+ def encode(self, src, src_mask):
45
+ return self.encoder(self.src_embed(src), src_mask)
46
+
47
+ def decode(self, memory, src_mask, tgt, tgt_mask):
48
+ return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
49
+
50
+ class Generator(nn.Module):
51
+ "Define standard linear + softmax generation step."
52
+ def __init__(self, d_model, vocab):
53
+ super(Generator, self).__init__()
54
+ self.proj = nn.Linear(d_model, vocab)
55
+
56
+ def forward(self, x):
57
+ return F.log_softmax(self.proj(x), dim=-1)
58
+
59
+ def clones(module, N):
60
+ "Produce N identical layers."
61
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
62
+
63
+ class Encoder(nn.Module):
64
+ "Core encoder is a stack of N layers"
65
+ def __init__(self, layer, N):
66
+ super(Encoder, self).__init__()
67
+ self.layers = clones(layer, N)
68
+ self.norm = LayerNorm(layer.size)
69
+
70
+ def forward(self, x, mask):
71
+ "Pass the input (and mask) through each layer in turn."
72
+ for layer in self.layers:
73
+ x = layer(x, mask)
74
+ return self.norm(x)
75
+
76
+ class LayerNorm(nn.Module):
77
+ "Construct a layernorm module (See citation for details)."
78
+ def __init__(self, features, eps=1e-6):
79
+ super(LayerNorm, self).__init__()
80
+ self.a_2 = nn.Parameter(torch.ones(features))
81
+ self.b_2 = nn.Parameter(torch.zeros(features))
82
+ self.eps = eps
83
+
84
+ def forward(self, x):
85
+ mean = x.mean(-1, keepdim=True)
86
+ std = x.std(-1, keepdim=True)
87
+ return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
88
+
89
+ class SublayerConnection(nn.Module):
90
+ """
91
+ A residual connection followed by a layer norm.
92
+ Note for code simplicity the norm is first as opposed to last.
93
+ """
94
+ def __init__(self, size, dropout):
95
+ super(SublayerConnection, self).__init__()
96
+ self.norm = LayerNorm(size)
97
+ self.dropout = nn.Dropout(dropout)
98
+
99
+ def forward(self, x, sublayer):
100
+ "Apply residual connection to any sublayer with the same size."
101
+ return x + self.dropout(sublayer(self.norm(x)))
102
+
103
+ class EncoderLayer(nn.Module):
104
+ "Encoder is made up of self-attn and feed forward (defined below)"
105
+ def __init__(self, size, self_attn, feed_forward, dropout):
106
+ super(EncoderLayer, self).__init__()
107
+ self.self_attn = self_attn
108
+ self.feed_forward = feed_forward
109
+ self.sublayer = clones(SublayerConnection(size, dropout), 2)
110
+ self.size = size
111
+
112
+ def forward(self, x, mask):
113
+ "Follow Figure 1 (left) for connections."
114
+ x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
115
+ return self.sublayer[1](x, self.feed_forward)
116
+
117
+ class Decoder(nn.Module):
118
+ "Generic N layer decoder with masking."
119
+ def __init__(self, layer, N):
120
+ super(Decoder, self).__init__()
121
+ self.layers = clones(layer, N)
122
+ self.norm = LayerNorm(layer.size)
123
+
124
+ def forward(self, x, memory, src_mask, tgt_mask):
125
+ for layer in self.layers:
126
+ x = layer(x, memory, src_mask, tgt_mask)
127
+ return self.norm(x)
128
+
129
+ class DecoderLayer(nn.Module):
130
+ "Decoder is made of self-attn, src-attn, and feed forward (defined below)"
131
+ def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
132
+ super(DecoderLayer, self).__init__()
133
+ self.size = size
134
+ self.self_attn = self_attn
135
+ self.src_attn = src_attn
136
+ self.feed_forward = feed_forward
137
+ self.sublayer = clones(SublayerConnection(size, dropout), 3)
138
+
139
+ def forward(self, x, memory, src_mask, tgt_mask):
140
+ "Follow Figure 1 (right) for connections."
141
+ m = memory
142
+ x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
143
+ x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
144
+ return self.sublayer[2](x, self.feed_forward)
145
+
146
+ def subsequent_mask(size):
147
+ "Mask out subsequent positions."
148
+ attn_shape = (1, size, size)
149
+ subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
150
+ return torch.from_numpy(subsequent_mask) == 0
151
+
152
+ def attention(query, key, value, mask=None, dropout=None):
153
+ "Compute 'Scaled Dot Product Attention'"
154
+ d_k = query.size(-1)
155
+ scores = torch.matmul(query, key.transpose(-2, -1)) \
156
+ / math.sqrt(d_k)
157
+ if mask is not None:
158
+ scores = scores.masked_fill(mask == 0, float('-inf'))
159
+ p_attn = F.softmax(scores, dim = -1)
160
+ if dropout is not None:
161
+ p_attn = dropout(p_attn)
162
+ return torch.matmul(p_attn, value), p_attn
163
+
164
+ class MultiHeadedAttention(nn.Module):
165
+ def __init__(self, h, d_model, dropout=0.1):
166
+ "Take in model size and number of heads."
167
+ super(MultiHeadedAttention, self).__init__()
168
+ assert d_model % h == 0
169
+ # We assume d_v always equals d_k
170
+ self.d_k = d_model // h
171
+ self.h = h
172
+ self.linears = clones(nn.Linear(d_model, d_model), 4)
173
+ self.attn = None
174
+ self.dropout = nn.Dropout(p=dropout)
175
+
176
+ def forward(self, query, key, value, mask=None):
177
+ "Implements Figure 2"
178
+ if mask is not None:
179
+ # Same mask applied to all h heads.
180
+ mask = mask.unsqueeze(1)
181
+ nbatches = query.size(0)
182
+
183
+ # 1) Do all the linear projections in batch from d_model => h x d_k
184
+ query, key, value = \
185
+ [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
186
+ for l, x in zip(self.linears, (query, key, value))]
187
+
188
+ # 2) Apply attention on all the projected vectors in batch.
189
+ x, self.attn = attention(query, key, value, mask=mask,
190
+ dropout=self.dropout)
191
+
192
+ # 3) "Concat" using a view and apply a final linear.
193
+ x = x.transpose(1, 2).contiguous() \
194
+ .view(nbatches, -1, self.h * self.d_k)
195
+ return self.linears[-1](x)
196
+
197
+ class PositionwiseFeedForward(nn.Module):
198
+ "Implements FFN equation."
199
+ def __init__(self, d_model, d_ff, dropout=0.1):
200
+ super(PositionwiseFeedForward, self).__init__()
201
+ self.w_1 = nn.Linear(d_model, d_ff)
202
+ self.w_2 = nn.Linear(d_ff, d_model)
203
+ self.dropout = nn.Dropout(dropout)
204
+
205
+ def forward(self, x):
206
+ return self.w_2(self.dropout(F.relu(self.w_1(x))))
207
+
208
+ class Embeddings(nn.Module):
209
+ def __init__(self, d_model, vocab):
210
+ super(Embeddings, self).__init__()
211
+ self.lut = nn.Embedding(vocab, d_model)
212
+ self.d_model = d_model
213
+
214
+ def forward(self, x):
215
+ return self.lut(x) * math.sqrt(self.d_model)
216
+
217
+ class PositionalEncoding(nn.Module):
218
+ "Implement the PE function."
219
+ def __init__(self, d_model, dropout, max_len=5000):
220
+ super(PositionalEncoding, self).__init__()
221
+ self.dropout = nn.Dropout(p=dropout)
222
+
223
+ # Compute the positional encodings once in log space.
224
+ pe = torch.zeros(max_len, d_model)
225
+ position = torch.arange(0, max_len).unsqueeze(1).float()
226
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() *
227
+ -(math.log(10000.0) / d_model))
228
+ pe[:, 0::2] = torch.sin(position * div_term)
229
+ pe[:, 1::2] = torch.cos(position * div_term)
230
+ pe = pe.unsqueeze(0)
231
+ self.register_buffer('pe', pe)
232
+
233
+ def forward(self, x):
234
+ x = x + self.pe[:, :x.size(1)]
235
+ return self.dropout(x)
236
+
237
+ class TransformerModel(AttModel):
238
+
239
+ def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6,
240
+ d_model=512, d_ff=2048, h=8, dropout=0.1):
241
+ "Helper: Construct a model from hyperparameters."
242
+ c = copy.deepcopy
243
+ attn = MultiHeadedAttention(h, d_model, dropout)
244
+ ff = PositionwiseFeedForward(d_model, d_ff, dropout)
245
+ position = PositionalEncoding(d_model, dropout)
246
+ model = EncoderDecoder(
247
+ Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N_enc),
248
+ Decoder(DecoderLayer(d_model, c(attn), c(attn),
249
+ c(ff), dropout), N_dec),
250
+ lambda x:x, # nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
251
+ nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
252
+ Generator(d_model, tgt_vocab))
253
+
254
+ # This was important from their code.
255
+ # Initialize parameters with Glorot / fan_avg.
256
+ for p in model.parameters():
257
+ if p.dim() > 1:
258
+ nn.init.xavier_uniform_(p)
259
+ return model
260
+
261
+ def __init__(self, opt):
262
+ super(TransformerModel, self).__init__(opt)
263
+ self.opt = opt
264
+ # self.config = yaml.load(open(opt.config_file))
265
+
266
+ self.N_enc = getattr(opt, 'N_enc', opt.num_layers)
267
+ self.N_dec = getattr(opt, 'N_dec', opt.num_layers)
268
+ self.d_model = getattr(opt, 'd_model', opt.input_encoding_size)
269
+ self.d_ff = getattr(opt, 'd_ff', opt.rnn_size)
270
+ self.h = getattr(opt, 'num_att_heads', 8)
271
+ self.dropout = getattr(opt, 'dropout', 0.1)
272
+
273
+ delattr(self, 'att_embed')
274
+ self.att_embed = nn.Sequential(*(
275
+ ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+
276
+ (nn.Linear(self.att_feat_size, self.d_model),
277
+ nn.ReLU(),
278
+ nn.Dropout(self.drop_prob_lm))+
279
+ ((nn.BatchNorm1d(self.d_model),) if self.use_bn==2 else ())))
280
+
281
+ delattr(self, 'embed')
282
+ self.embed = lambda x : x
283
+ delattr(self, 'fc_embed')
284
+ self.fc_embed = lambda x : x
285
+ delattr(self, 'logit')
286
+ del self.ctx2att
287
+
288
+ tgt_vocab = self.vocab_size + 1
289
+
290
+
291
+ self.model = self.make_model(0, tgt_vocab,
292
+ N_enc=self.N_enc,
293
+ N_dec=self.N_dec,
294
+ d_model=self.d_model,
295
+ d_ff=self.d_ff,
296
+ h=self.h,
297
+ dropout=self.dropout)
298
+
299
+ def logit(self, x): # unsafe way
300
+ return self.model.generator.proj(x)
301
+
302
+ def init_hidden(self, bsz):
303
+ return []
304
+
305
+ def _prepare_feature(self, fc_feats, att_feats, att_masks):
306
+
307
+ att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks)
308
+ memory = self.model.encode(att_feats, att_masks)
309
+
310
+ return fc_feats[...,:0], att_feats[...,:0], memory, att_masks
311
+
312
+ def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None):
313
+ att_feats, att_masks = self.clip_att(att_feats, att_masks)
314
+
315
+ att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
316
+
317
+ if att_masks is None:
318
+ att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long)
319
+ att_masks = att_masks.unsqueeze(-2)
320
+
321
+ if seq is not None:
322
+ # crop the last one
323
+ # seq = seq[:,:-1]
324
+ seq_mask = (seq.data != self.eos_idx) & (seq.data != self.pad_idx)
325
+ seq_mask[:,0] = 1 # bos
326
+
327
+ seq_mask = seq_mask.unsqueeze(-2)
328
+ seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask)
329
+
330
+ seq_per_img = seq.shape[0] // att_feats.shape[0]
331
+ if seq_per_img > 1:
332
+ att_feats, att_masks = utils.repeat_tensors(seq_per_img,
333
+ [att_feats, att_masks]
334
+ )
335
+ else:
336
+ seq_mask = None
337
+
338
+ return att_feats, seq, att_masks, seq_mask
339
+
340
+ def _forward(self, fc_feats, att_feats, seq, att_masks=None):
341
+ if seq.ndim == 3: # B * seq_per_img * seq_len
342
+ seq = seq.reshape(-1, seq.shape[2])
343
+ att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq)
344
+
345
+ out = self.model(att_feats, seq, att_masks, seq_mask)
346
+
347
+ outputs = self.model.generator(out)
348
+ return outputs
349
+ # return torch.cat([_.unsqueeze(1) for _ in outputs], 1)
350
+
351
+ def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
352
+ """
353
+ state = [ys.unsqueeze(0)]
354
+ """
355
+ if len(state) == 0:
356
+ ys = it.unsqueeze(1)
357
+ else:
358
+ ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
359
+ out = self.model.decode(memory, mask,
360
+ ys,
361
+ subsequent_mask(ys.size(1))
362
+ .to(memory.device))
363
+ return out[:, -1], [ys.unsqueeze(0)]
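For reference, subsequent_mask above yields a lower-triangular boolean mask (True = position may be attended to), which is what makes the decoder causal. A quick check, re-using the same definition:

import numpy as np
import torch

def subsequent_mask(size):
    attn_shape = (1, size, size)
    mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(mask) == 0

print(subsequent_mask(3))
# tensor([[[ True, False, False],
#          [ True,  True, False],
#          [ True,  True,  True]]])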
captioning/models/__init__.py ADDED
@@ -0,0 +1,73 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import os
6
+ import copy
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+ from .ShowTellModel import ShowTellModel
12
+ from .FCModel import FCModel
13
+ from .AttModel import *
14
+ from .TransformerModel import TransformerModel
15
+ from .cachedTransformer import TransformerModel as cachedTransformer
16
+ from .BertCapModel import BertCapModel
17
+ from .M2Transformer import M2TransformerModel
18
+ from .AoAModel import AoAModel
19
+
20
+ def setup(opt):
21
+ if opt.caption_model in ['fc', 'show_tell']:
22
+ print('Warning: %s model is mostly deprecated; many new features are not supported.' %opt.caption_model)
23
+ if opt.caption_model == 'fc':
24
+ print('Use newfc instead of fc')
25
+ if opt.caption_model == 'fc':
26
+ model = FCModel(opt)
27
+ elif opt.caption_model == 'language_model':
28
+ model = LMModel(opt)
29
+ elif opt.caption_model == 'newfc':
30
+ model = NewFCModel(opt)
31
+ elif opt.caption_model == 'show_tell':
32
+ model = ShowTellModel(opt)
33
+ # Att2in model in self-critical
34
+ elif opt.caption_model == 'att2in':
35
+ model = Att2inModel(opt)
36
+ # Att2in model with two-layer MLP img embedding and word embedding
37
+ elif opt.caption_model == 'att2in2':
38
+ model = Att2in2Model(opt)
39
+ elif opt.caption_model == 'att2all2':
40
+ print('Warning: this is not a correct implementation of the att2all model in the original paper.')
41
+ model = Att2all2Model(opt)
42
+ # Adaptive Attention model from Knowing when to look
43
+ elif opt.caption_model == 'adaatt':
44
+ model = AdaAttModel(opt)
45
+ # Adaptive Attention with maxout lstm
46
+ elif opt.caption_model == 'adaattmo':
47
+ model = AdaAttMOModel(opt)
48
+ # Top-down attention model
49
+ elif opt.caption_model in ['topdown', 'updown']:
50
+ model = UpDownModel(opt)
51
+ # StackAtt
52
+ elif opt.caption_model == 'stackatt':
53
+ model = StackAttModel(opt)
54
+ # DenseAtt
55
+ elif opt.caption_model == 'denseatt':
56
+ model = DenseAttModel(opt)
57
+ # Transformer
58
+ elif opt.caption_model == 'transformer':
59
+ if getattr(opt, 'cached_transformer', False):
60
+ model = cachedTransformer(opt)
61
+ else:
62
+ model = TransformerModel(opt)
63
+ # AoANet
64
+ elif opt.caption_model == 'aoa':
65
+ model = AoAModel(opt)
66
+ elif opt.caption_model == 'bert':
67
+ model = BertCapModel(opt)
68
+ elif opt.caption_model == 'm2transformer':
69
+ model = M2TransformerModel(opt)
70
+ else:
71
+ raise Exception("Caption model not supported: {}".format(opt.caption_model))
72
+
73
+ return model
captioning/models/cachedTransformer.py ADDED
@@ -0,0 +1,420 @@
1
+ # This file contains Transformer network
2
+ # Most of the code is copied from http://nlp.seas.harvard.edu/2018/04/03/attention.html
3
+
4
+ # The cfg name correspondence:
5
+ # N=num_layers
6
+ # d_model=input_encoding_size
7
+ # d_ff=rnn_size
8
+ # h is always 8
9
+
10
+ from __future__ import absolute_import
11
+ from __future__ import division
12
+ from __future__ import print_function
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from . import utils
18
+
19
+ import copy
20
+ import math
21
+ import numpy as np
22
+
23
+ from .CaptionModel import CaptionModel
24
+ from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel
25
+
26
+ class EncoderDecoder(nn.Module):
27
+ """
28
+ A standard Encoder-Decoder architecture. Base for this and many
29
+ other models.
30
+ """
31
+ def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
32
+ super(EncoderDecoder, self).__init__()
33
+ self.encoder = encoder
34
+ self.decoder = decoder
35
+ self.src_embed = src_embed
36
+ self.tgt_embed = tgt_embed
37
+ self.generator = generator
38
+
39
+ def forward(self, src, tgt, src_mask, tgt_mask):
40
+ "Take in and process masked src and target sequences."
41
+ return self.decode(self.encode(src, src_mask), src_mask,
42
+ tgt, tgt_mask)
43
+
44
+ def encode(self, src, src_mask):
45
+ return self.encoder(self.src_embed(src), src_mask)
46
+
47
+ def decode(self, memory, src_mask, tgt, tgt_mask, past=None):
48
+ return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask, past=past)
49
+
50
+ class Generator(nn.Module):
51
+ "Define standard linear + softmax generation step."
52
+ def __init__(self, d_model, vocab):
53
+ super(Generator, self).__init__()
54
+ self.proj = nn.Linear(d_model, vocab)
55
+
56
+ def forward(self, x):
57
+ return F.log_softmax(self.proj(x), dim=-1)
58
+
59
+ def clones(module, N):
60
+ "Produce N identical layers."
61
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
62
+
63
+ class Encoder(nn.Module):
64
+ "Core encoder is a stack of N layers"
65
+ def __init__(self, layer, N):
66
+ super(Encoder, self).__init__()
67
+ self.layers = clones(layer, N)
68
+ self.norm = LayerNorm(layer.size)
69
+
70
+ def forward(self, x, mask):
71
+ "Pass the input (and mask) through each layer in turn."
72
+ for layer in self.layers:
73
+ x = layer(x, mask)
74
+ return self.norm(x)
75
+
76
+ class LayerNorm(nn.Module):
77
+ "Construct a layernorm module (See citation for details)."
78
+ def __init__(self, features, eps=1e-6):
79
+ super(LayerNorm, self).__init__()
80
+ self.a_2 = nn.Parameter(torch.ones(features))
81
+ self.b_2 = nn.Parameter(torch.zeros(features))
82
+ self.eps = eps
83
+
84
+ def forward(self, x):
85
+ mean = x.mean(-1, keepdim=True)
86
+ std = x.std(-1, keepdim=True)
87
+ return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
88
+
89
+ class SublayerConnection(nn.Module):
90
+ """
91
+ A residual connection followed by a layer norm.
92
+ Note for code simplicity the norm is first as opposed to last.
93
+ """
94
+ def __init__(self, size, dropout):
95
+ super(SublayerConnection, self).__init__()
96
+ self.norm = LayerNorm(size)
97
+ self.dropout = nn.Dropout(dropout)
98
+
99
+ def forward(self, x, sublayer):
100
+ "Apply residual connection to any sublayer with the same size."
101
+ _x = sublayer(self.norm(x))
102
+ if type(_x) is tuple: # for multi-head attention that returns past
103
+ return x + self.dropout(_x[0]), _x[1]
104
+ return x + self.dropout(_x)
105
+
106
+ class EncoderLayer(nn.Module):
107
+ "Encoder is made up of self-attn and feed forward (defined below)"
108
+ def __init__(self, size, self_attn, feed_forward, dropout):
109
+ super(EncoderLayer, self).__init__()
110
+ self.self_attn = self_attn
111
+ self.feed_forward = feed_forward
112
+ self.sublayer = clones(SublayerConnection(size, dropout), 2)
113
+ self.size = size
114
+
115
+ def forward(self, x, mask):
116
+ "Follow Figure 1 (left) for connections."
117
+ x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
118
+ return self.sublayer[1](x, self.feed_forward)
119
+
120
+ class Decoder(nn.Module):
121
+ "Generic N layer decoder with masking."
122
+ def __init__(self, layer, N):
123
+ super(Decoder, self).__init__()
124
+ self.layers = clones(layer, N)
125
+ self.norm = LayerNorm(layer.size)
126
+
127
+ def forward(self, x, memory, src_mask, tgt_mask, past=None):
128
+ if past is not None:
129
+ present = [[], []]
130
+ x = x[:, -1:]
131
+ tgt_mask = tgt_mask[:, -1:] if tgt_mask is not None else None
132
+ past = list(zip(past[0].split(2, dim=0), past[1].split(2, dim=0)))
133
+ else:
134
+ past = [None] * len(self.layers)
135
+ for i, (layer, layer_past) in enumerate(zip(self.layers, past)):
136
+ x = layer(x, memory, src_mask, tgt_mask,
137
+ layer_past)
138
+ if layer_past is not None:
139
+ present[0].append(x[1][0])
140
+ present[1].append(x[1][1])
141
+ x = x[0]
142
+ if past[0] is None:
143
+ return self.norm(x)
144
+ else:
145
+ return self.norm(x), [torch.cat(present[0], 0), torch.cat(present[1], 0)]
146
+
147
+
148
+ class DecoderLayer(nn.Module):
149
+ "Decoder is made of self-attn, src-attn, and feed forward (defined below)"
150
+ def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
151
+ super(DecoderLayer, self).__init__()
152
+ self.size = size
153
+ self.self_attn = self_attn
154
+ self.src_attn = src_attn
155
+ self.feed_forward = feed_forward
156
+ self.sublayer = clones(SublayerConnection(size, dropout), 3)
157
+
158
+ def forward(self, x, memory, src_mask, tgt_mask, layer_past=None):
159
+ "Follow Figure 1 (right) for connections."
160
+ m = memory
161
+ if layer_past is None:
162
+ x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
163
+ x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
164
+ return self.sublayer[2](x, self.feed_forward)
165
+ else:
166
+ present = [None, None]
167
+ x, present[0] = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask, layer_past[0]))
168
+ x, present[1] = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask, layer_past[1]))
169
+ return self.sublayer[2](x, self.feed_forward), present
170
+
171
+ def subsequent_mask(size):
172
+ "Mask out subsequent positions."
173
+ attn_shape = (1, size, size)
174
+ subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
175
+ return torch.from_numpy(subsequent_mask) == 0
176
+
177
+ def attention(query, key, value, mask=None, dropout=None):
178
+ "Compute 'Scaled Dot Product Attention'"
179
+ d_k = query.size(-1)
180
+ scores = torch.matmul(query, key.transpose(-2, -1)) \
181
+ / math.sqrt(d_k)
182
+ if mask is not None:
183
+ scores = scores.masked_fill(mask == 0, float('-inf'))
184
+ p_attn = F.softmax(scores, dim = -1)
185
+ if dropout is not None:
186
+ p_attn = dropout(p_attn)
187
+ return torch.matmul(p_attn, value), p_attn
188
+
189
+ class MultiHeadedAttention(nn.Module):
190
+ def __init__(self, h, d_model, dropout=0.1):
191
+ "Take in model size and number of heads."
192
+ super(MultiHeadedAttention, self).__init__()
193
+ assert d_model % h == 0
194
+ # We assume d_v always equals d_k
195
+ self.d_k = d_model // h
196
+ self.h = h
197
+ self.linears = clones(nn.Linear(d_model, d_model), 4)
198
+ self.attn = None
199
+ self.dropout = nn.Dropout(p=dropout)
200
+
201
+ def forward(self, query, key, value, mask=None, layer_past=None):
202
+ "Implements Figure 2"
203
+ if mask is not None:
204
+ # Same mask applied to all h heads.
205
+ mask = mask.unsqueeze(1)
206
+ nbatches = query.size(0)
207
+
208
+ # The past works differently here. For self attn, the cached key and value are updated incrementally.
209
+ # For src_attn the past is fixed.
210
+
211
+ # For src_attn, when the layer past is ready
212
+ if layer_past is not None and layer_past.shape[2] == key.shape[1] > 1: # assumes the memory length is always greater than 1
213
+ query = self.linears[0](query)
214
+ key, value = layer_past[0], layer_past[1]
215
+ present = torch.stack([key, value])
216
+ else:
217
+ # 1) Do all the linear projections in batch from d_model => h x d_k
218
+ query, key, value = \
219
+ [l(x) for l, x in zip(self.linears, (query, key, value))]
220
+
221
+ # self attn + past OR the first time step of src attn
222
+ if layer_past is not None and not (layer_past.shape[2] == key.shape[1] > 1):
223
+ past_key, past_value = layer_past[0], layer_past[1]
224
+ key = torch.cat((past_key, key), dim=1)
225
+ value = torch.cat((past_value, value), dim=1)
226
+ present = torch.stack([key, value])
227
+
228
+ query, key, value = \
229
+ [x.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
230
+ for x in [query, key, value]]
231
+
232
+ # 2) Apply attention on all the projected vectors in batch.
233
+ x, self.attn = attention(query, key, value, mask=mask,
234
+ dropout=self.dropout)
235
+
236
+ # 3) "Concat" using a view and apply a final linear.
237
+ x = x.transpose(1, 2).contiguous() \
238
+ .view(nbatches, -1, self.h * self.d_k)
239
+ if layer_past is not None:
240
+ return self.linears[-1](x), present
241
+ else:
242
+ return self.linears[-1](x)
243
+
244
+ class PositionwiseFeedForward(nn.Module):
245
+ "Implements FFN equation."
246
+ def __init__(self, d_model, d_ff, dropout=0.1):
247
+ super(PositionwiseFeedForward, self).__init__()
248
+ self.w_1 = nn.Linear(d_model, d_ff)
249
+ self.w_2 = nn.Linear(d_ff, d_model)
250
+ self.dropout = nn.Dropout(dropout)
251
+
252
+ def forward(self, x):
253
+ return self.w_2(self.dropout(F.relu(self.w_1(x))))
254
+
255
+ class Embeddings(nn.Module):
256
+ def __init__(self, d_model, vocab):
257
+ super(Embeddings, self).__init__()
258
+ self.lut = nn.Embedding(vocab, d_model)
259
+ self.d_model = d_model
260
+
261
+ def forward(self, x):
262
+ return self.lut(x) * math.sqrt(self.d_model)
263
+
264
+ class PositionalEncoding(nn.Module):
265
+ "Implement the PE function."
266
+ def __init__(self, d_model, dropout, max_len=5000):
267
+ super(PositionalEncoding, self).__init__()
268
+ self.dropout = nn.Dropout(p=dropout)
269
+
270
+ # Compute the positional encodings once in log space.
271
+ pe = torch.zeros(max_len, d_model)
272
+ position = torch.arange(0, max_len).unsqueeze(1).float()
273
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() *
274
+ -(math.log(10000.0) / d_model))
275
+ pe[:, 0::2] = torch.sin(position * div_term)
276
+ pe[:, 1::2] = torch.cos(position * div_term)
277
+ pe = pe.unsqueeze(0)
278
+ self.register_buffer('pe', pe)
279
+
280
+ def forward(self, x):
281
+ x = x + self.pe[:, :x.size(1)]
282
+ return self.dropout(x)
283
+
284
+ class TransformerModel(AttModel):
285
+
286
+ def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6,
287
+ d_model=512, d_ff=2048, h=8, dropout=0.1):
288
+ "Helper: Construct a model from hyperparameters."
289
+ c = copy.deepcopy
290
+ attn = MultiHeadedAttention(h, d_model, dropout)
291
+ ff = PositionwiseFeedForward(d_model, d_ff, dropout)
292
+ position = PositionalEncoding(d_model, dropout)
293
+ model = EncoderDecoder(
294
+ Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N_enc),
295
+ Decoder(DecoderLayer(d_model, c(attn), c(attn),
296
+ c(ff), dropout), N_dec),
297
+ lambda x:x, # nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
298
+ nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
299
+ Generator(d_model, tgt_vocab))
300
+
301
+ # This was important from their code.
302
+ # Initialize parameters with Glorot / fan_avg.
303
+ for p in model.parameters():
304
+ if p.dim() > 1:
305
+ nn.init.xavier_uniform_(p)
306
+ return model
307
+
308
+ def __init__(self, opt):
309
+ super(TransformerModel, self).__init__(opt)
310
+ self.opt = opt
311
+ # self.config = yaml.load(open(opt.config_file))
312
+
313
+ self.N_enc = getattr(opt, 'N_enc', opt.num_layers)
314
+ self.N_dec = getattr(opt, 'N_dec', opt.num_layers)
315
+ self.d_model = getattr(opt, 'd_model', opt.input_encoding_size)
316
+ self.d_ff = getattr(opt, 'd_ff', opt.rnn_size)
317
+ self.h = getattr(opt, 'num_att_heads', 8)
318
+ self.dropout = getattr(opt, 'dropout', 0.1)
319
+
320
+ delattr(self, 'att_embed')
321
+ self.att_embed = nn.Sequential(*(
322
+ ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+
323
+ (nn.Linear(self.att_feat_size, self.d_model),
324
+ nn.ReLU(),
325
+ nn.Dropout(self.drop_prob_lm))+
326
+ ((nn.BatchNorm1d(self.d_model),) if self.use_bn==2 else ())))
327
+
328
+ delattr(self, 'embed')
329
+ self.embed = lambda x : x
330
+ delattr(self, 'fc_embed')
331
+ self.fc_embed = lambda x : x
332
+ delattr(self, 'logit')
333
+ del self.ctx2att
334
+
335
+ tgt_vocab = self.vocab_size + 1
336
+
337
+
338
+ self.model = self.make_model(0, tgt_vocab,
339
+ N_enc=self.N_enc,
340
+ N_dec=self.N_dec,
341
+ d_model=self.d_model,
342
+ d_ff=self.d_ff,
343
+ h=self.h,
344
+ dropout=self.dropout)
345
+
346
+ def logit(self, x): # unsafe way
347
+ return self.model.generator.proj(x)
348
+
349
+ def init_hidden(self, bsz):
350
+ return []
351
+
352
+ def _prepare_feature(self, fc_feats, att_feats, att_masks):
353
+
354
+ att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks)
355
+ memory = self.model.encode(att_feats, att_masks)
356
+
357
+ return fc_feats[...,:0], att_feats[...,:0], memory, att_masks
358
+
359
+ def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None):
360
+ att_feats, att_masks = self.clip_att(att_feats, att_masks)
361
+
362
+ att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
363
+
364
+ if att_masks is None:
365
+ att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long)
366
+ att_masks = att_masks.unsqueeze(-2)
367
+
368
+ if seq is not None:
369
+ # crop the last one
370
+ # seq = seq[:,:-1]
371
+ seq_mask = (seq.data != self.eos_idx) & (seq.data != self.pad_idx)
372
+ seq_mask[:,0] = 1 # bos
373
+
374
+ seq_mask = seq_mask.unsqueeze(-2)
375
+ seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask)
376
+
377
+ seq_per_img = seq.shape[0] // att_feats.shape[0]
378
+ if seq_per_img > 1:
379
+ att_feats, att_masks = utils.repeat_tensors(seq_per_img,
380
+ [att_feats, att_masks]
381
+ )
382
+ else:
383
+ seq_mask = None
384
+
385
+ return att_feats, seq, att_masks, seq_mask
386
+
387
+ def _forward(self, fc_feats, att_feats, seq, att_masks=None):
388
+ if seq.ndim == 3: # B * seq_per_img * seq_len
389
+ seq = seq.reshape(-1, seq.shape[2])
390
+ att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq)
391
+
392
+ out = self.model(att_feats, seq, att_masks, seq_mask)
393
+
394
+ outputs = self.model.generator(out)
395
+ return outputs
396
+ # return torch.cat([_.unsqueeze(1) for _ in outputs], 1)
397
+
398
+ def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
399
+ """
400
+ state is the precomputed key/value. N_dec x seq_len x d_model
401
+ Note: due to the layer norm, it's not equivalent to stateless decoding,
402
+ but it seems to behave similarly.
403
+ """
404
+ # state is tokens + past
405
+ if len(state) == 0:
406
+ ys = it.unsqueeze(1)
407
+ # basically empty state, just to let it know to return past
408
+ # The second dim has to be batch_size, for beam search purpose
409
+ past = [fc_feats_ph.new_zeros(self.N_dec * 2, fc_feats_ph.shape[0], 0, self.d_model), # self
410
+ fc_feats_ph.new_zeros(self.N_dec * 2, fc_feats_ph.shape[0], 0, self.d_model)] # src
411
+ # 2 for self attn, 2 for src attn
412
+ else:
413
+ ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
414
+ past = state[1:]
415
+ out, past = self.model.decode(memory, mask,
416
+ ys, # We still feed all past words, so the positional embedding knows the current position id
417
+ subsequent_mask(ys.size(1))
418
+ .to(memory.device),
419
+ past=past)
420
+ return out[:, -1], [ys.unsqueeze(0)] + past
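The main difference from TransformerModel above is that each decoder layer keeps the keys/values it has already computed and appends one step at a time. A minimal sketch of that append-and-attend pattern, independent of the classes above (all shapes are illustrative):

import torch

B, d_model = 2, 512
past_k = torch.zeros(B, 0, d_model)   # empty cache before the first step
past_v = torch.zeros(B, 0, d_model)

for t in range(3):
    new_k = torch.randn(B, 1, d_model)            # stand-in for the current token's key projection
    new_v = torch.randn(B, 1, d_model)
    past_k = torch.cat([past_k, new_k], dim=1)    # cache grows by one position per step
    past_v = torch.cat([past_v, new_v], dim=1)
    q = torch.randn(B, 1, d_model)                # only the newest query is needed
    scores = torch.matmul(q, past_k.transpose(-2, -1)) / d_model ** 0.5
    out = torch.matmul(torch.softmax(scores, dim=-1), past_v)   # (B, 1, d_model)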
captioning/models/utils.py ADDED
@@ -0,0 +1,25 @@
1
+ import torch
2
+
3
+ def repeat_tensors(n, x):
4
+ """
5
+ For a tensor of size Bx..., we repeat it n times, and make it Bnx...
6
+ For collections, do nested repeat
7
+ """
8
+ if torch.is_tensor(x):
9
+ x = x.unsqueeze(1) # Bx1x...
10
+ x = x.expand(-1, n, *([-1]*len(x.shape[2:]))) # Bxnx...
11
+ x = x.reshape(x.shape[0]*n, *x.shape[2:]) # Bnx...
12
+ elif type(x) is list or type(x) is tuple:
13
+ x = [repeat_tensors(n, _) for _ in x]
14
+ return x
15
+
16
+
17
+ def split_tensors(n, x):
18
+ if torch.is_tensor(x):
19
+ assert x.shape[0] % n == 0
20
+ x = x.reshape(x.shape[0] // n, n, *x.shape[1:]).unbind(1)
21
+ elif type(x) is list or type(x) is tuple:
22
+ x = [split_tensors(n, _) for _ in x]
23
+ elif x is None:
24
+ x = [None] * n
25
+ return x
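A small usage check for the two helpers above (this assumes the captioning package is importable, e.g. the repository root is on PYTHONPATH):

import torch
from captioning.models.utils import repeat_tensors, split_tensors

x = torch.arange(6).view(2, 3)      # B = 2
rep = repeat_tensors(3, x)          # shape (6, 3): each row repeated 3 times, grouped per row
parts = split_tensors(3, rep)       # 3 tensors of shape (2, 3) that undo the repeat
assert all(torch.equal(p, x) for p in parts)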
captioning/modules/loss_wrapper.py ADDED
@@ -0,0 +1,127 @@
1
+ import torch
2
+ from . import losses
3
+ from ..utils.rewards import init_scorer, get_self_critical_reward, get_self_critical_clipscore_reward
4
+ from ..utils.clipscore import CLIPScore
5
+ import numpy as np
6
+
7
+ class LossWrapper(torch.nn.Module):
8
+ def __init__(self, model, opt):
9
+ super(LossWrapper, self).__init__()
10
+ self.opt = opt
11
+ self.model = model
12
+ if opt.label_smoothing > 0:
13
+ self.crit = losses.LabelSmoothing(smoothing=opt.label_smoothing)
14
+ else:
15
+ self.crit = losses.LanguageModelCriterion()
16
+ self.rl_crit = losses.RewardCriterion()
17
+ self.struc_crit = losses.StructureLosses(opt)
18
+
19
+ self.clipscore_model = None
20
+ if self.opt.use_clipscore:
21
+ use_grammar = getattr(self.opt, 'use_grammar', False)
22
+ joint_out = getattr(self.opt, 'joint_out', False)
23
+ self.clipscore_model = CLIPScore(
24
+ mode=opt.clipscore_mode,
25
+ use_grammar=use_grammar,
26
+ joint_out=joint_out,
27
+ )
28
+ for p in self.clipscore_model.parameters():
29
+ p.requires_grad = False
30
+
31
+ if use_grammar:
32
+ state_dict = torch.load(self.opt.clip_load_path, map_location='cpu')
33
+ self.clipscore_model.load_state_dict(state_dict['state_dict'])
34
+
35
+ def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts, gt_indices,
36
+ sc_flag, struc_flag, clip_vis_feats=None):
37
+ opt = self.opt
38
+
39
+ out = {}
40
+ if struc_flag:
41
+ if opt.structure_loss_weight < 1:
42
+ lm_loss = self.crit(self.model(fc_feats, att_feats, labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:])
43
+ else:
44
+ lm_loss = torch.tensor(0).type_as(fc_feats)
45
+ if opt.structure_loss_weight > 0:
46
+ gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks,
47
+ opt={'sample_method':opt.train_sample_method,
48
+ 'beam_size':opt.train_beam_size,
49
+ 'output_logsoftmax': opt.struc_use_logsoftmax or opt.structure_loss_type == 'softmax_margin'\
50
+ or not 'margin' in opt.structure_loss_type,
51
+ 'sample_n': opt.train_sample_n},
52
+ mode='sample')
53
+ gts = [gts[_] for _ in gt_indices.tolist()]
54
+ struc_loss = self.struc_crit(sample_logprobs, gen_result, gts)
55
+ else:
56
+ struc_loss = {'loss': torch.tensor(0).type_as(fc_feats),
57
+ 'reward': torch.tensor(0).type_as(fc_feats)}
58
+ loss = (1-opt.structure_loss_weight) * lm_loss + opt.structure_loss_weight * struc_loss['loss']
59
+ out['lm_loss'] = lm_loss
60
+ out['struc_loss'] = struc_loss['loss']
61
+ out['reward'] = struc_loss['reward']
62
+ elif not sc_flag:
63
+ loss = self.crit(self.model(fc_feats, att_feats, labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:])
64
+ else:
65
+ self.model.eval()
66
+ with torch.no_grad():
67
+ greedy_res, _ = self.model(fc_feats, att_feats, att_masks,
68
+ mode='sample',
69
+ opt={'sample_method': opt.sc_sample_method,
70
+ 'beam_size': opt.sc_beam_size})
71
+ self.model.train()
72
+ gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks,
73
+ opt={'sample_method':opt.train_sample_method,
74
+ 'beam_size':opt.train_beam_size,
75
+ 'sample_n': opt.train_sample_n},
76
+ mode='sample')
77
+ gts = [gts[_] for _ in gt_indices.tolist()]
78
+
79
+ if getattr(self.opt, 'use_multi_rewards', False):
80
+ assert self.opt.use_clipscore
81
+ clipscore_reward_normalized, clipscore_unnormalized_mean, grammar_rewards = get_self_critical_clipscore_reward(
82
+ greedy_res, gts, gen_result, self.opt, self.clipscore_model, clip_vis_feats, self.model.vocab)
83
+
84
+ if self.opt.clipscore_mode == 'clip_s':
85
+ out['CLIP-S'] = clipscore_unnormalized_mean
86
+ elif self.opt.clipscore_mode == 'refclip_s':
87
+ out['RefCLIP-S'] = clipscore_unnormalized_mean
88
+
89
+ if getattr(self.opt, 'use_grammar', False):
90
+ out['grammar_reward'] = grammar_rewards.mean()
91
+
92
+ reward = clipscore_reward_normalized + grammar_rewards
93
+
94
+
95
+ else:
96
+ assert grammar_rewards is None
97
+
98
+ cider_reward_normalized, cider_unnormalized_mean = get_self_critical_reward(
99
+ greedy_res, gts, gen_result, self.opt)
100
+ out['CIDEr'] = cider_unnormalized_mean
101
+ if isinstance(cider_reward_normalized, np.ndarray):
102
+ cider_reward_normalized = torch.from_numpy(cider_reward_normalized).to(clipscore_reward_normalized.device)
103
+
104
+ reward = clipscore_reward_normalized + cider_reward_normalized
105
+ else:
106
+ if self.opt.use_clipscore:
107
+ clipscore_reward_normalized, clipscore_unnormalized_mean, _ = get_self_critical_clipscore_reward(
108
+ greedy_res, gts, gen_result, self.opt, self.clipscore_model, clip_vis_feats, self.model.vocab)
109
+ if self.opt.clipscore_mode == 'clip_s':
110
+ out['CLIP-S'] = clipscore_unnormalized_mean
111
+ elif self.opt.clipscore_mode == 'refclip_s':
112
+ out['RefCLIP-S'] = clipscore_unnormalized_mean
113
+ reward = clipscore_reward_normalized
114
+ else:
115
+ cider_reward_normalized, cider_unnormalized_mean = get_self_critical_reward(
116
+ greedy_res, gts, gen_result, self.opt)
117
+ out['CIDEr'] = cider_unnormalized_mean
118
+ reward = cider_reward_normalized
119
+
120
+ if isinstance(reward, np.ndarray):
121
+ reward = torch.from_numpy(reward)
122
+ reward = reward.to(sample_logprobs)
123
+ loss = self.rl_crit(sample_logprobs, gen_result.data, reward)
124
+ out['reward'] = reward[:,0].mean()
125
+ out['loss'] = loss
126
+ return out
127
+
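The self-critical branch above boils down to: score the sampled captions and a greedy baseline with the chosen metric, use the difference as a per-caption reward, and scale the sampled log-probabilities by it. A tiny numeric sketch with hypothetical scores (not real metric outputs):

import torch

sample_scores = torch.tensor([0.9, 0.4, 0.7])   # hypothetical CIDEr / CLIP-S of sampled captions
greedy_scores = torch.tensor([0.6, 0.5, 0.7])   # hypothetical scores of the greedy baseline
reward = sample_scores - greedy_scores          # positive pushes a sample up, negative down

sample_logprobs = torch.tensor([-1.2, -0.8, -1.5])   # made-up sequence log-probs
loss = -(sample_logprobs * reward).mean()            # REINFORCE-style loss (token masking omitted)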
captioning/modules/losses.py ADDED
@@ -0,0 +1,218 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from ..utils.rewards import get_scores, get_self_cider_scores
4
+
5
+ class RewardCriterion(nn.Module):
6
+ def __init__(self):
7
+ super(RewardCriterion, self).__init__()
8
+
9
+ def forward(self, input, seq, reward):
10
+ input = input.gather(2, seq.unsqueeze(2)).squeeze(2)
11
+
12
+ input = input.reshape(-1)
13
+ reward = reward.reshape(-1)
14
+ mask = (seq>0).to(input)
15
+ mask = torch.cat([mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1).reshape(-1)
16
+ output = - input * reward * mask
17
+ output = torch.sum(output) / torch.sum(mask)
18
+
19
+ return output
20
+
21
+ class StructureLosses(nn.Module):
22
+ """
23
+ This loss is inspired by Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018).
24
+ """
25
+ def __init__(self, opt):
26
+ super(StructureLosses, self).__init__()
27
+ self.opt = opt
28
+ self.loss_type = opt.structure_loss_type
29
+
30
+ def forward(self, input, seq, data_gts):
31
+ """
32
+ Input is either logits or log softmax
33
+ """
34
+ out = {}
35
+
36
+ batch_size = input.size(0)# batch_size = sample_size * seq_per_img
37
+ seq_per_img = batch_size // len(data_gts)
38
+
39
+ assert seq_per_img == self.opt.train_sample_n, seq_per_img
40
+
41
+ mask = (seq>0).to(input)
42
+ mask = torch.cat([mask.new_full((mask.size(0), 1), 1), mask[:, :-1]], 1)
43
+
44
+ scores = get_scores(data_gts, seq, self.opt)
45
+ scores = torch.from_numpy(scores).type_as(input).view(-1, seq_per_img)
46
+ out['reward'] = scores #.mean()
47
+ if self.opt.entropy_reward_weight > 0:
48
+ entropy = - (F.softmax(input, dim=2) * F.log_softmax(input, dim=2)).sum(2).data
49
+ entropy = (entropy * mask).sum(1) / mask.sum(1)
50
+ print('entropy', entropy.mean().item())
51
+ scores = scores + self.opt.entropy_reward_weight * entropy.view(-1, seq_per_img)
52
+ # rescale cost to [0,1]
53
+ costs = - scores
54
+ if self.loss_type == 'risk' or self.loss_type == 'softmax_margin':
55
+ costs = costs - costs.min(1, keepdim=True)[0]
56
+ costs = costs / costs.max(1, keepdim=True)[0]
57
+ # in principle
58
+ # Only risk need such rescale
59
+ # margin should be alright; Let's try.
60
+
61
+ # Gather input: BxTxD -> BxT
62
+ input = input.gather(2, seq.unsqueeze(2)).squeeze(2)
63
+
64
+ if self.loss_type == 'seqnll':
65
+ # input is logsoftmax
66
+ input = input * mask
67
+ input = input.sum(1) / mask.sum(1)
68
+ input = input.view(-1, seq_per_img)
69
+
70
+ target = costs.min(1)[1]
71
+ output = F.cross_entropy(input, target)
72
+ elif self.loss_type == 'risk':
73
+ # input is logsoftmax
74
+ input = input * mask
75
+ input = input.sum(1)
76
+ input = input.view(-1, seq_per_img)
77
+
78
+ output = (F.softmax(input.exp()) * costs).sum(1).mean()
79
+
80
+ # test
81
+ # avg_scores = input
82
+ # probs = F.softmax(avg_scores.exp_())
83
+ # loss = (probs * costs.type_as(probs)).sum() / input.size(0)
84
+ # print(output.item(), loss.item())
85
+
86
+ elif self.loss_type == 'max_margin':
87
+ # input is logits
88
+ input = input * mask
89
+ input = input.sum(1) / mask.sum(1)
90
+ input = input.view(-1, seq_per_img)
91
+ _, __ = costs.min(1, keepdim=True)
92
+ costs_star = _
93
+ input_star = input.gather(1, __)
94
+ output = F.relu(costs - costs_star - input_star + input).max(1)[0] / 2
95
+ output = output.mean()
96
+
97
+ # sanity test
98
+ # avg_scores = input + costs
99
+ # scores_with_high_target = avg_scores.clone()
100
+ # scores_with_high_target.scatter_(1, costs.min(1)[1].view(-1, 1), 1e10)
101
+
102
+ # target_and_offender_index = scores_with_high_target.sort(1, True)[1][:, 0:2]
103
+ # avg_scores = avg_scores.gather(1, target_and_offender_index)
104
+ # target_index = avg_scores.new_zeros(avg_scores.size(0), dtype=torch.long)
105
+ # loss = F.multi_margin_loss(avg_scores, target_index, size_average=True, margin=0)
106
+ # print(loss.item() * 2, output.item())
107
+
108
+ elif self.loss_type == 'multi_margin':
109
+ # input is logits
110
+ input = input * mask
111
+ input = input.sum(1) / mask.sum(1)
112
+ input = input.view(-1, seq_per_img)
113
+ _, __ = costs.min(1, keepdim=True)
114
+ costs_star = _
115
+ input_star = input.gather(1, __)
116
+ output = F.relu(costs - costs_star - input_star + input)
117
+ output = output.mean()
118
+
119
+ # sanity test
120
+ # avg_scores = input + costs
121
+ # loss = F.multi_margin_loss(avg_scores, costs.min(1)[1], margin=0)
122
+ # print(output, loss)
123
+
124
+ elif self.loss_type == 'softmax_margin':
125
+ # input is logsoftmax
126
+ input = input * mask
127
+ input = input.sum(1) / mask.sum(1)
128
+ input = input.view(-1, seq_per_img)
129
+
130
+ input = input + costs
131
+ target = costs.min(1)[1]
132
+ output = F.cross_entropy(input, target)
133
+
134
+ elif self.loss_type == 'real_softmax_margin':
135
+ # input is logits
136
+ # This is what originally defined in Kevin's paper
137
+ # The result should be equivalent to softmax_margin
138
+ input = input * mask
139
+ input = input.sum(1) / mask.sum(1)
140
+ input = input.view(-1, seq_per_img)
141
+
142
+ input = input + costs
143
+ target = costs.min(1)[1]
144
+ output = F.cross_entropy(input, target)
145
+
146
+ elif self.loss_type == 'new_self_critical':
147
+ """
148
+ A different self critical
149
+ Self critical uses greedy decoding score as baseline;
150
+ This setting uses the average score of the rest samples as baseline
151
+ (suppose c1...cn n samples, reward1 = score1 - 1/(n-1)(score2+..+scoren) )
152
+ """
153
+ baseline = (scores.sum(1, keepdim=True) - scores) / (scores.shape[1] - 1)
154
+ scores = scores - baseline
155
+ # self cider used as reward to promote diversity (not working that much in this way)
156
+ if getattr(self.opt, 'self_cider_reward_weight', 0) > 0:
157
+ _scores = get_self_cider_scores(data_gts, seq, self.opt)
158
+ _scores = torch.from_numpy(_scores).type_as(scores).view(-1, 1)
159
+ _scores = _scores.expand_as(scores - 1)
160
+ scores += self.opt.self_cider_reward_weight * _scores
161
+ output = - input * mask * scores.view(-1, 1)
162
+ output = torch.sum(output) / torch.sum(mask)
163
+
164
+ out['loss'] = output
165
+ return out
166
+
167
+ class LanguageModelCriterion(nn.Module):
168
+ def __init__(self):
169
+ super(LanguageModelCriterion, self).__init__()
170
+
171
+ def forward(self, input, target, mask):
172
+ if target.ndim == 3:
173
+ target = target.reshape(-1, target.shape[2])
174
+ mask = mask.reshape(-1, mask.shape[2])
175
+ # truncate to the same size
176
+ target = target[:, :input.size(1)]
177
+ mask = mask[:, :input.size(1)].to(input)
178
+
179
+ output = -input.gather(2, target.unsqueeze(2)).squeeze(2) * mask
180
+ # Average over each token
181
+ output = torch.sum(output) / torch.sum(mask)
182
+
183
+ return output
184
+
185
+ class LabelSmoothing(nn.Module):
186
+ "Implement label smoothing."
187
+ def __init__(self, size=0, padding_idx=0, smoothing=0.0):
188
+ super(LabelSmoothing, self).__init__()
189
+ self.criterion = nn.KLDivLoss(size_average=False, reduce=False)
190
+ # self.padding_idx = padding_idx
191
+ self.confidence = 1.0 - smoothing
192
+ self.smoothing = smoothing
193
+ # self.size = size
194
+ self.true_dist = None
195
+
196
+ def forward(self, input, target, mask):
197
+ if target.ndim == 3:
198
+ target = target.reshape(-1, target.shape[2])
199
+ mask = mask.reshape(-1, mask.shape[2])
200
+ # truncate to the same size
201
+ target = target[:, :input.size(1)]
202
+ mask = mask[:, :input.size(1)]
203
+
204
+ input = input.reshape(-1, input.size(-1))
205
+ target = target.reshape(-1)
206
+ mask = mask.reshape(-1).to(input)
207
+
208
+ # assert x.size(1) == self.size
209
+ self.size = input.size(1)
210
+ # true_dist = x.data.clone()
211
+ true_dist = input.data.clone()
212
+ # true_dist.fill_(self.smoothing / (self.size - 2))
213
+ true_dist.fill_(self.smoothing / (self.size - 1))
214
+ true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
215
+ # true_dist[:, self.padding_idx] = 0
216
+ # mask = torch.nonzero(target.data == self.padding_idx)
217
+ # self.true_dist = true_dist
218
+ return (self.criterion(input, true_dist).sum(1) * mask).sum() / mask.sum()
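For intuition, the smoothed target that LabelSmoothing builds keeps 1 - smoothing on the gold index and spreads the remaining mass evenly over the other entries (the fill uses size - 1, so the gold index briefly receives smoothing mass before scatter_ overwrites it). A tiny numeric check with made-up values:

import torch

smoothing, confidence = 0.1, 0.9
size = 4                                  # toy vocabulary size
target = torch.tensor([2])                # gold index

true_dist = torch.full((1, size), smoothing / (size - 1))
true_dist.scatter_(1, target.unsqueeze(1), confidence)
print(true_dist)   # tensor([[0.0333, 0.0333, 0.9000, 0.0333]])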
captioning/utils/__init__.py ADDED
File without changes
captioning/utils/clipscore.py ADDED
@@ -0,0 +1,396 @@
1
+ from transformers import CLIPModel, CLIPTokenizer
2
+ import os
3
+ import json
4
+ import argparse
5
+ from random import shuffle, seed
6
+ import string
7
+ # non-standard dependencies:
8
+ import h5py
9
+ from six.moves import cPickle
10
+ import numpy as np
11
+ import torch
12
+ import torchvision.models as models
13
+ import skimage.io
14
+
15
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
16
+ from PIL import Image
17
+ from torch import nn
18
+
19
+
20
+ class CLIPScore(nn.Module):
21
+ def __init__(self, clipscore_w=2.5, image_size=224, mode='clip_s', use_grammar=False, joint_out=False):
22
+ super(CLIPScore, self).__init__()
23
+ # from transformers import CLIPModel, CLIPTokenizer
24
+ self.clip_model = CLIPModel.from_pretrained(
25
+ 'openai/clip-vit-base-patch32')
26
+ self.tokenizer = CLIPTokenizer.from_pretrained(
27
+ 'openai/clip-vit-base-patch32')
28
+
29
+ self.clip_model.eval()
30
+
31
+ self.clipscore_w = clipscore_w
32
+
33
+ self.image_transform = self._transform(image_size)
34
+
35
+ self.mode = mode
36
+ assert mode in ['clip_s', 'refclip_s']
37
+
38
+ self.use_grammar = use_grammar
39
+ self.joint_out = joint_out
40
+
41
+ if self.use_grammar and joint_out is False:
42
+ self.grammar_score_head = nn.Sequential(
43
+ nn.Linear(self.clip_model.text_embed_dim, self.clip_model.projection_dim, bias=False),
44
+ nn.ReLU(),
45
+ nn.Linear(self.clip_model.projection_dim, 2, bias=False)
46
+ )
47
+
48
+ def _transform(self, n_px):
49
+ return Compose([
50
+ Resize(n_px, interpolation=Image.BICUBIC),
51
+ CenterCrop(n_px),
52
+ lambda image: image.convert("RGB"),
53
+ ToTensor(),
54
+ Normalize((0.48145466, 0.4578275, 0.40821073),
55
+ (0.26862954, 0.26130258, 0.27577711)),
56
+ ])
57
+
58
+ def load_image(self, image_path):
59
+ image = Image.open(image_path)
60
+ return image
61
+
62
+ # @torch.no_grad()
63
+ def image_extract(self, image):
64
+ if isinstance(image, str):
65
+ image = self.load_image(image)
66
+ if not isinstance(image, torch.Tensor):
67
+ image = self.image_transform(image)
68
+
69
+ img_tensor = image.view(-1, 3, 224, 224)
70
+ device = next(self.clip_model.parameters()).device
71
+ img_tensor = img_tensor.to(device)
72
+
73
+ clip_model = self.clip_model
74
+
75
+ img_feat = clip_model.vision_model(img_tensor).pooler_output
76
+ img_feat = clip_model.visual_projection(img_feat)
77
+ img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
78
+
79
+ return img_feat
80
+
81
+ # @torch.no_grad()
82
+ def text_extract(self, text, prompt="A photo depicts", proj_norm=True):
83
+ if isinstance(text, str):
84
+ text_batch = [" ".join([prompt, text])]
85
+ elif isinstance(text, list):
86
+ text_batch = [" ".join([prompt, txt]) for txt in text]
87
+
88
+ if isinstance(text, tuple) and isinstance(text[0], torch.Tensor):
89
+ input_ids, attention_mask = text
90
+ else:
91
+ input_text = text_batch
92
+
93
+ tokenized = self.tokenizer(
94
+ input_text, return_tensors='pt', padding=True, truncation=True)
95
+
96
+ input_ids = tokenized.input_ids
97
+ attention_mask = tokenized.attention_mask
98
+
99
+ clip_model = self.clip_model
100
+ device = next(self.clip_model.parameters()).device
101
+ input_ids = input_ids.to(device)
102
+ attention_mask = attention_mask.to(device)
103
+
104
+ text_feat = clip_model.text_model(input_ids, attention_mask).pooler_output
105
+
106
+ if proj_norm:
107
+ text_feat = clip_model.text_projection(text_feat)
108
+ text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)
109
+
110
+ return text_feat
111
+
112
+ # @torch.no_grad()
113
+ def calc_clip_s(self, img_feat, text_feat):
114
+ return self.clipscore_w * torch.relu((img_feat * text_feat).sum(dim=-1))
115
+
116
+ # @torch.no_grad()
117
+ def calc_refclip_s(self, img_feat=None, text_feat=None, ref_text_feat=None, ref_text_mask=None, clip_s=None):
118
+
119
+ if clip_s is None:
120
+ clip_s = self.calc_clip_s(img_feat, text_feat)
121
+
122
+ B, dim = img_feat.size()
123
+
124
+ ref_text_feat = ref_text_feat.view(B, -1, dim)
125
+
126
+ K = ref_text_feat.size(1)
127
+
128
+ text_feat = text_feat.view(B, 1, dim).expand(-1, K, -1)
129
+ assert ref_text_feat.size() == text_feat.size(
130
+ ), (ref_text_feat.size(), text_feat.size())
131
+
132
+ ref_score = self.calc_clip_s(text_feat, ref_text_feat)
133
+ if ref_text_mask is not None:
134
+ if not isinstance(ref_text_mask, torch.Tensor):
135
+ ref_text_mask = torch.tensor(
136
+ ref_text_mask, dtype=ref_score.dtype, device=ref_score.device)
137
+ ref_score = ref_score.view(B, K) * ref_text_mask.view(B, K)
138
+
139
+ ref_score = ref_score.view(B, K).max(dim=1).values
140
+
141
+ assert clip_s.size() == (B,)
142
+ assert clip_s.size() == ref_score.size()
143
+
144
+ # harmonic mean
145
+ refclip_s = 2 / (1 / clip_s + 1 / ref_score)
146
+ return refclip_s
147
+
148
+ @torch.no_grad()
149
+ def forward(self,
150
+ images=None, text=None,
151
+ img_feat=None, text_feat=None,
152
+ ref_text=None, ref_text_feat=None, ref_text_mask=None,
153
+ prompt="A photo depicts",
154
+ mode=None):
155
+ if img_feat is None:
156
+ img_feat = self.image_extract(images)
157
+ img_feat = img_feat.view(-1, 512)
158
+
159
+ B = img_feat.size(0)
160
+
161
+ if text_feat is None:
162
+ text_feat = self.text_extract(text, prompt=prompt)
163
+ text_feat = text_feat.view(-1, 512)
164
+
165
+ if mode is None:
166
+ mode = self.mode
167
+ assert mode in ['clip_s', 'refclip_s']
168
+
169
+ if mode == 'clip_s':
170
+ clip_s = self.calc_clip_s(img_feat, text_feat)
171
+ return clip_s
172
+ elif mode == 'refclip_s':
173
+ if ref_text_feat is None:
174
+ ref_text_feat = self.text_extract(ref_text, prompt=prompt)
175
+ ref_text_feat = ref_text_feat.view(-1, 512)
176
+
177
+ refclip_s = self.calc_refclip_s(
178
+ img_feat, text_feat, ref_text_feat, ref_text_mask=ref_text_mask)
179
+ return refclip_s
180
+
181
+
182
+ def train_step(self,
183
+ images=None, text=None,
184
+ img_feat=None, text_feat=None,
185
+ neg_text=None, neg_text_feat=None,
186
+ # ref_text=None, ref_text_feat=None, ref_text_mask=None,
187
+ prompt="A photo depicts",
188
+ # return_loss=True,
189
+ **kwargs):
190
+
191
+ if img_feat is None:
192
+ img_feat = self.image_extract(images)
193
+ img_feat = img_feat.view(-1, 512)
194
+
195
+ B = img_feat.size(0)
196
+
197
+ if text_feat is None:
198
+ text_feat = self.text_extract(text, prompt=prompt, proj_norm=False)
199
+
200
+ text_cont_feat = self.clip_model.text_projection(text_feat)
201
+ text_cont_feat = text_cont_feat / text_cont_feat.norm(dim=-1, keepdim=True)
202
+ text_cont_feat = text_cont_feat.view(B, 512)
203
+
204
+ # cosine similarity as logits
205
+ logit_scale = self.clip_model.logit_scale.exp()
206
+ logits_per_text = torch.matmul(text_cont_feat, img_feat.t()) * logit_scale
207
+ # logits_per_image = logits_per_text.T
208
+
209
+ clip_loss = clip_loss_fn(logits_per_text)
210
+
211
+
212
+ # negative sampling
213
+ pos_text_feat = text_feat.view(B, 512)
214
+ neg_text_feat = self.text_extract(neg_text, prompt=prompt, proj_norm=False).view(B, 512)
215
+
216
+ grammar_text_feat = torch.cat([pos_text_feat, neg_text_feat], dim=0)
217
+
218
+ # 2B, 1
219
+ grammar_text_logit = self.grammar_score_head(grammar_text_feat)
220
+ grammar_labels = torch.LongTensor([1] * B + [0] * B).to(grammar_text_logit.device).view(2 * B)
221
+
222
+ grammar_loss = torch.nn.functional.cross_entropy(grammar_text_logit, grammar_labels)
223
+
224
+ grammar_pred = grammar_text_logit.argmax(dim=1, keepdim=False)
225
+ grammar_pos_pred = grammar_pred[:B]
226
+ grammar_neg_pred = grammar_pred[B:]
227
+ # grammar_acc = (grammar_pred == grammar_labels).float().mean()
228
+
229
+ out = {
230
+ 'clip_loss': clip_loss,
231
+ 'grammar_loss': grammar_loss,
232
+ 'img_feat': img_feat,
233
+ 'text_feat': text_cont_feat,
234
+ 'neg_text_feat': neg_text_feat,
235
+ 'grammar_pos_pred': grammar_pos_pred,
236
+ 'grammar_neg_pred': grammar_neg_pred,
237
+ }
238
+
239
+ return out
240
+
241
+ # contrastive loss function, adapted from
242
+ # https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
243
+ def contrastive_loss(logits: torch.Tensor, dim: int) -> torch.Tensor:
244
+ neg_ce = torch.diag(nn.functional.log_softmax(logits, dim=dim))
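+ # the diagonal holds the matched image-text pairs; maximizing their log-softmax along `dim` is the standard InfoNCE objective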
245
+ return -neg_ce.mean()
246
+
247
+
248
+ def clip_loss_fn(similarity: torch.Tensor) -> torch.Tensor:
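+ # symmetric CLIP objective: average the contrastive loss over the text and image axes of the similarity matrix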
249
+ caption_loss = contrastive_loss(similarity, dim=0)
250
+ image_loss = contrastive_loss(similarity, dim=1)
251
+ return (caption_loss + image_loss) / 2.0
252
+
253
+
254
+
255
+ # class CLIPScore(nn.Module):
256
+ # def __init__(self, clipscore_w=2.5, image_size=224, mode='clip_s'):
257
+ # super(CLIPScore, self).__init__()
258
+ # # from transformers import CLIPModel, CLIPTokenizer
259
+ # self.clip_model = CLIPModel.from_pretrained(
260
+ # 'openai/clip-vit-base-patch32')
261
+ # self.tokenizer = CLIPTokenizer.from_pretrained(
262
+ # 'openai/clip-vit-base-patch32')
263
+
264
+ # self.clip_model.eval()
265
+
266
+ # self.clipscore_w = clipscore_w
267
+
268
+ # self.image_transform = self._transform(image_size)
269
+
270
+ # self.mode = mode
271
+ # assert mode in ['clip_s', 'refclip_s']
272
+
273
+ # def _transform(self, n_px):
274
+ # return Compose([
275
+ # Resize(n_px, interpolation=Image.BICUBIC),
276
+ # CenterCrop(n_px),
277
+ # lambda image: image.convert("RGB"),
278
+ # ToTensor(),
279
+ # Normalize((0.48145466, 0.4578275, 0.40821073),
280
+ # (0.26862954, 0.26130258, 0.27577711)),
281
+ # ])
282
+
283
+ # def load_image(self, image_path):
284
+ # image = Image.open(image_path)
285
+ # return image
286
+
287
+ # @torch.no_grad()
288
+ # def image_extract(self, image):
289
+ # if isinstance(image, str):
290
+ # image = self.load_image(image)
291
+ # if not isinstance(image, torch.Tensor):
292
+ # image = self.image_transform(image)
293
+
294
+ # img_tensor = image.view(-1, 3, 224, 224)
295
+ # device = next(self.clip_model.parameters()).device
296
+ # img_tensor = img_tensor.to(device)
297
+
298
+ # clip_model = self.clip_model
299
+
300
+ # img_feat = clip_model.vision_model(img_tensor).pooler_output
301
+ # img_feat = clip_model.visual_projection(img_feat)
302
+ # img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
303
+
304
+ # return img_feat
305
+
306
+ # @torch.no_grad()
307
+ # def text_extract(self, text, prompt="A photo depicts"):
308
+ # if isinstance(text, str):
309
+ # text_batch = [" ".join([prompt, text])]
310
+ # else:
311
+ # text_batch = [" ".join([prompt, txt]) for txt in text]
312
+
313
+ # input_text = text_batch
314
+
315
+ # tokenized = self.tokenizer(
316
+ # input_text, return_tensors='pt', padding=True)
317
+
318
+ # input_ids = tokenized.input_ids
319
+ # attention_mask = tokenized.attention_mask
320
+
321
+ # clip_model = self.clip_model
322
+ # device = next(self.clip_model.parameters()).device
323
+ # input_ids = input_ids.to(device)
324
+ # attention_mask = attention_mask.to(device)
325
+
326
+ # text_feat = clip_model.text_model(input_ids, attention_mask).pooler_output
327
+ # text_feat = clip_model.text_projection(text_feat)
328
+ # text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)
329
+
330
+ # return text_feat
331
+
332
+ # @torch.no_grad()
333
+ # def calc_clip_s(self, img_feat, text_feat):
334
+ # return self.clipscore_w * torch.relu((img_feat * text_feat).sum(dim=-1))
335
+
336
+ # @torch.no_grad()
337
+ # def calc_refclip_s(self, img_feat=None, text_feat=None, ref_text_feat=None, ref_text_mask=None, clip_s=None):
338
+
339
+ # if clip_s is None:
340
+ # clip_s = self.calc_clip_s(img_feat, text_feat)
341
+
342
+ # B, dim = img_feat.size()
343
+
344
+ # ref_text_feat = ref_text_feat.view(B, -1, dim)
345
+
346
+ # K = ref_text_feat.size(1)
347
+
348
+ # text_feat = text_feat.view(B, 1, dim).expand(-1, K, -1)
349
+ # assert ref_text_feat.size() == text_feat.size(), (ref_text_feat.size(), text_feat.size())
350
+
351
+ # ref_score = self.calc_clip_s(text_feat, ref_text_feat)
352
+ # if ref_text_mask is not None:
353
+ # if not isinstance(ref_text_mask, torch.Tensor):
354
+ # ref_text_mask = torch.tensor(ref_text_mask, dtype=ref_score.dtype, device=ref_score.device)
355
+ # ref_score = ref_score.view(B, K) * ref_text_mask.view(B, K)
356
+
357
+ # ref_score = ref_score.view(B, K).max(dim=1).values
358
+
359
+ # assert clip_s.size() == (B,)
360
+ # assert clip_s.size() == ref_score.size()
361
+
362
+ # # harmonic mean
363
+ # refclip_s = 2 / (1 / clip_s + 1 / ref_score)
364
+ # return refclip_s
365
+
366
+
367
+ # @torch.no_grad()
368
+ # def forward(self,
369
+ # images=None, text=None,
370
+ # img_feat=None, text_feat=None,
371
+ # ref_text=None, ref_text_feat=None, ref_text_mask=None,
372
+ # prompt="A photo depicts",
373
+ # mode=None):
374
+ # if img_feat is None:
375
+ # img_feat = self.image_extract(images)
376
+ # img_feat = img_feat.view(-1, 512)
377
+
378
+ # if text_feat is None:
379
+ # text_feat = self.text_extract(text, prompt=prompt)
380
+ # text_feat = text_feat.view(-1, 512)
381
+
382
+ # if mode is None:
383
+ # mode = self.mode
384
+ # assert mode in ['clip_s', 'refclip_s']
385
+
386
+ # if mode == 'clip_s':
387
+ # clip_s = self.calc_clip_s(img_feat, text_feat)
388
+ # return clip_s
389
+ # elif mode == 'refclip_s':
390
+ # if ref_text_feat is None:
391
+ # ref_text_feat = self.text_extract(ref_text, prompt=prompt)
392
+ # ref_text_feat = ref_text_feat.view(-1, 512)
393
+
394
+ # refclip_s = self.calc_refclip_s(img_feat, text_feat, ref_text_feat, ref_text_mask=ref_text_mask)
395
+ # return refclip_s
396
+
captioning/utils/config.py ADDED
@@ -0,0 +1,153 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+ # Copy from fvcore
3
+
4
+ import logging
5
+ import os
6
+ from typing import Any
7
+ import yaml
8
+ from yacs.config import CfgNode as _CfgNode
9
+
10
+ import io as PathManager
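+ # io stands in for fvcore's PathManager here; io.open is the built-in open()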
11
+
12
+ BASE_KEY = "_BASE_"
13
+
14
+
15
+ class CfgNode(_CfgNode):
16
+ """
17
+ Our own extended version of :class:`yacs.config.CfgNode`.
18
+ It contains the following extra features:
19
+
20
+ 1. The :meth:`merge_from_file` method supports the "_BASE_" key,
21
+ which allows the new CfgNode to inherit all the attributes from the
22
+ base configuration file.
23
+ 2. Keys that start with "COMPUTED_" are treated as insertion-only
24
+ "computed" attributes. They can be inserted regardless of whether
25
+ the CfgNode is frozen or not.
26
+ 3. With "allow_unsafe=True", it supports pyyaml tags that evaluate
27
+ expressions in config. See examples in
28
+ https://pyyaml.org/wiki/PyYAMLDocumentation#yaml-tags-and-python-types
29
+ Note that this may lead to arbitrary code execution: you must not
30
+ load a config file from untrusted sources before manually inspecting
31
+ the content of the file.
32
+ """
33
+
34
+ @staticmethod
35
+ def load_yaml_with_base(filename, allow_unsafe = False):
36
+ """
37
+ Just like `yaml.load(open(filename))`, but inherit attributes from its
38
+ `_BASE_`.
39
+
40
+ Args:
41
+ filename (str): the file name of the current config. Will be used to
42
+ find the base config file.
43
+ allow_unsafe (bool): whether to allow loading the config file with
44
+ `yaml.unsafe_load`.
45
+
46
+ Returns:
47
+ (dict): the loaded yaml
48
+ """
49
+ with PathManager.open(filename, "r") as f:
50
+ try:
51
+ cfg = yaml.safe_load(f)
52
+ except yaml.constructor.ConstructorError:
53
+ if not allow_unsafe:
54
+ raise
55
+ logger = logging.getLogger(__name__)
56
+ logger.warning(
57
+ "Loading config {} with yaml.unsafe_load. Your machine may "
58
+ "be at risk if the file contains malicious content.".format(
59
+ filename
60
+ )
61
+ )
62
+ f.close()
63
+ with open(filename, "r") as f:
64
+ cfg = yaml.unsafe_load(f)
65
+
66
+ def merge_a_into_b(a, b):
67
+ # merge dict a into dict b. values in a will overwrite b.
68
+ for k, v in a.items():
69
+ if isinstance(v, dict) and k in b:
70
+ assert isinstance(
71
+ b[k], dict
72
+ ), "Cannot inherit key '{}' from base!".format(k)
73
+ merge_a_into_b(v, b[k])
74
+ else:
75
+ b[k] = v
76
+
77
+ if BASE_KEY in cfg:
78
+ base_cfg_file = cfg[BASE_KEY]
79
+ if base_cfg_file.startswith("~"):
80
+ base_cfg_file = os.path.expanduser(base_cfg_file)
81
+ if not any(
82
+ map(base_cfg_file.startswith, ["/", "https://", "http://"])
83
+ ):
84
+ # the path to base cfg is relative to the config file itself.
85
+ base_cfg_file = os.path.join(
86
+ os.path.dirname(filename), base_cfg_file
87
+ )
88
+ base_cfg = CfgNode.load_yaml_with_base(
89
+ base_cfg_file, allow_unsafe=allow_unsafe
90
+ )
91
+ del cfg[BASE_KEY]
92
+
93
+ merge_a_into_b(cfg, base_cfg)
94
+ return base_cfg
95
+ return cfg
96
+
97
+ def merge_from_file(self, cfg_filename, allow_unsafe = False):
98
+ """
99
+ Merge configs from a given yaml file.
100
+
101
+ Args:
102
+ cfg_filename: the file name of the yaml config.
103
+ allow_unsafe: whether to allow loading the config file with
104
+ `yaml.unsafe_load`.
105
+ """
106
+ loaded_cfg = CfgNode.load_yaml_with_base(
107
+ cfg_filename, allow_unsafe=allow_unsafe
108
+ )
109
+ loaded_cfg = type(self)(loaded_cfg)
110
+ self.merge_from_other_cfg(loaded_cfg)
111
+
112
+ # Forward the following calls to base, but with a check on the BASE_KEY.
113
+ def merge_from_other_cfg(self, cfg_other):
114
+ """
115
+ Args:
116
+ cfg_other (CfgNode): configs to merge from.
117
+ """
118
+ assert (
119
+ BASE_KEY not in cfg_other
120
+ ), "The reserved key '{}' can only be used in files!".format(BASE_KEY)
121
+ return super().merge_from_other_cfg(cfg_other)
122
+
123
+ def merge_from_list(self, cfg_list):
124
+ """
125
+ Args:
126
+ cfg_list (list): list of configs to merge from.
127
+ """
128
+ keys = set(cfg_list[0::2])
129
+ assert (
130
+ BASE_KEY not in keys
131
+ ), "The reserved key '{}' can only be used in files!".format(BASE_KEY)
132
+ return super().merge_from_list(cfg_list)
133
+
134
+ def __setattr__(self, name, val):
135
+ if name.startswith("COMPUTED_"):
136
+ if name in self:
137
+ old_val = self[name]
138
+ if old_val == val:
139
+ return
140
+ raise KeyError(
141
+ "Computed attributed '{}' already exists "
142
+ "with a different value! old={}, new={}.".format(
143
+ name, old_val, val
144
+ )
145
+ )
146
+ self[name] = val
147
+ else:
148
+ super().__setattr__(name, val)
149
+
150
+
151
+ if __name__ == '__main__':
152
+ cfg = CfgNode.load_yaml_with_base('configs/updown_long.yml')
153
+ print(cfg)
captioning/utils/dist_utils.py ADDED
@@ -0,0 +1,305 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ This file contains primitives for multi-gpu communication.
4
+ This is useful when doing distributed training.
5
+ """
6
+
7
+ import functools
8
+ import logging
9
+ import numpy as np
10
+ import pickle
11
+ import torch
12
+ import torch.distributed as dist
13
+
14
+ import torch
15
+
16
+ _LOCAL_PROCESS_GROUP = None
17
+ """
18
+ A torch process group which only includes processes that are on the same machine as the current process.
19
+ This variable is set when processes are spawned by `launch()` in "engine/launch.py".
20
+ """
21
+
22
+
23
+ def get_world_size() -> int:
24
+ if not dist.is_available():
25
+ return 1
26
+ if not dist.is_initialized():
27
+ return 1
28
+ return dist.get_world_size()
29
+
30
+
31
+ def get_rank() -> int:
32
+ if not dist.is_available():
33
+ return 0
34
+ if not dist.is_initialized():
35
+ return 0
36
+ return dist.get_rank()
37
+
38
+
39
+ def get_local_rank() -> int:
40
+ """
41
+ Returns:
42
+ The rank of the current process within the local (per-machine) process group.
43
+ """
44
+ if not dist.is_available():
45
+ return 0
46
+ if not dist.is_initialized():
47
+ return 0
48
+ assert _LOCAL_PROCESS_GROUP is not None
49
+ return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
50
+
51
+
52
+ def get_local_size() -> int:
53
+ """
54
+ Returns:
55
+ The size of the per-machine process group,
56
+ i.e. the number of processes per machine.
57
+ """
58
+ if not dist.is_available():
59
+ return 1
60
+ if not dist.is_initialized():
61
+ return 1
62
+ return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
63
+
64
+
65
+ def is_main_process() -> bool:
66
+ return get_rank() == 0
67
+
68
+
69
+ def synchronize():
70
+ """
71
+ Helper function to synchronize (barrier) among all processes when
72
+ using distributed training
73
+ """
74
+ if not dist.is_available():
75
+ return
76
+ if not dist.is_initialized():
77
+ return
78
+ world_size = dist.get_world_size()
79
+ if world_size == 1:
80
+ return
81
+ dist.barrier()
82
+
83
+
84
+ @functools.lru_cache()
85
+ def _get_global_gloo_group():
86
+ """
87
+ Return a process group based on gloo backend, containing all the ranks
88
+ The result is cached.
89
+ """
90
+ if dist.get_backend() == "nccl":
91
+ return dist.new_group(backend="gloo")
92
+ else:
93
+ return dist.group.WORLD
94
+
95
+
96
+ def _serialize_to_tensor(data, group):
97
+ backend = dist.get_backend(group)
98
+ assert backend in ["gloo", "nccl"]
99
+ device = torch.device("cpu" if backend == "gloo" else "cuda")
100
+
101
+ buffer = pickle.dumps(data)
102
+ if len(buffer) > 1024 ** 3:
103
+ logger = logging.getLogger(__name__)
104
+ logger.warning(
105
+ "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
106
+ get_rank(), len(buffer) / (1024 ** 3), device
107
+ )
108
+ )
109
+ storage = torch.ByteStorage.from_buffer(buffer)
110
+ tensor = torch.ByteTensor(storage).to(device=device)
111
+ return tensor
112
+
113
+
114
+ def _pad_to_largest_tensor(tensor, group):
115
+ """
116
+ Returns:
117
+ list[int]: size of the tensor, on each rank
118
+ Tensor: padded tensor that has the max size
119
+ """
120
+ world_size = dist.get_world_size(group=group)
121
+ assert (
122
+ world_size >= 1
123
+ ), "comm.gather/all_gather must be called from ranks within the given group!"
124
+ local_size = torch.tensor(
125
+ [tensor.numel()], dtype=torch.int64, device=tensor.device)
126
+ size_list = [
127
+ torch.zeros([1], dtype=torch.int64, device=tensor.device)
128
+ for _ in range(world_size)
129
+ ]
130
+ dist.all_gather(size_list, local_size, group=group)
131
+ size_list = [int(size.item()) for size in size_list]
132
+
133
+ max_size = max(size_list)
134
+
135
+ # we pad the tensor because torch all_gather does not support
136
+ # gathering tensors of different shapes
137
+ if local_size != max_size:
138
+ padding = torch.zeros(
139
+ (max_size - local_size,), dtype=torch.uint8, device=tensor.device
140
+ )
141
+ tensor = torch.cat((tensor, padding), dim=0)
142
+ return size_list, tensor
143
+
144
+
145
+ def all_gather(data, group=None):
146
+ """
147
+ Run all_gather on arbitrary picklable data (not necessarily tensors).
148
+ Args:
149
+ data: any picklable object
150
+ group: a torch process group. By default, will use a group which
151
+ contains all ranks on gloo backend.
152
+ Returns:
153
+ list[data]: list of data gathered from each rank
154
+ """
155
+ if get_world_size() == 1:
156
+ return [data]
157
+ if group is None:
158
+ group = _get_global_gloo_group()
159
+ if dist.get_world_size(group) == 1:
160
+ return [data]
161
+
162
+ tensor = _serialize_to_tensor(data, group)
163
+
164
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
165
+ max_size = max(size_list)
166
+
167
+ # receiving Tensor from all ranks
168
+ tensor_list = [
169
+ torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
170
+ for _ in size_list
171
+ ]
172
+ dist.all_gather(tensor_list, tensor, group=group)
173
+
174
+ data_list = []
175
+ for size, tensor in zip(size_list, tensor_list):
176
+ buffer = tensor.cpu().numpy().tobytes()[:size]
177
+ data_list.append(pickle.loads(buffer))
178
+
179
+ return data_list
180
+
181
+
182
+ def gather(data, dst=0, group=None):
183
+ """
184
+ Run gather on arbitrary picklable data (not necessarily tensors).
185
+ Args:
186
+ data: any picklable object
187
+ dst (int): destination rank
188
+ group: a torch process group. By default, will use a group which
189
+ contains all ranks on gloo backend.
190
+ Returns:
191
+ list[data]: on dst, a list of data gathered from each rank. Otherwise,
192
+ an empty list.
193
+ """
194
+ if get_world_size() == 1:
195
+ return [data]
196
+ if group is None:
197
+ group = _get_global_gloo_group()
198
+ if dist.get_world_size(group=group) == 1:
199
+ return [data]
200
+ rank = dist.get_rank(group=group)
201
+
202
+ tensor = _serialize_to_tensor(data, group)
203
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
204
+
205
+ # receiving Tensor from all ranks
206
+ if rank == dst:
207
+ max_size = max(size_list)
208
+ tensor_list = [
209
+ torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
210
+ for _ in size_list
211
+ ]
212
+ dist.gather(tensor, tensor_list, dst=dst, group=group)
213
+
214
+ data_list = []
215
+ for size, tensor in zip(size_list, tensor_list):
216
+ buffer = tensor.cpu().numpy().tobytes()[:size]
217
+ data_list.append(pickle.loads(buffer))
218
+ return data_list
219
+ else:
220
+ dist.gather(tensor, [], dst=dst, group=group)
221
+ return []
222
+
223
+
224
+ def shared_random_seed():
225
+ """
226
+ Returns:
227
+ int: a random number that is the same across all workers.
228
+ If workers need a shared RNG, they can use this shared seed to
229
+ create one.
230
+ All workers must call this function, otherwise it will deadlock.
231
+ """
232
+ ints = np.random.randint(2 ** 31)
233
+ all_ints = all_gather(ints)
234
+ return all_ints[0]
235
+
236
+
237
+ # def reduce_dict(input_dict, average=True):
238
+ # """
239
+ # Reduce the values in the dictionary from all processes so that process with rank
240
+ # 0 has the reduced results.
241
+ # Args:
242
+ # input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
243
+ # average (bool): whether to do average or sum
244
+ # Returns:
245
+ # a dict with the same keys as input_dict, after reduction.
246
+ # """
247
+ # world_size = get_world_size()
248
+ # if world_size < 2:
249
+ # return input_dict
250
+ # with torch.no_grad():
251
+ # names = []
252
+ # values = []
253
+ # # sort the keys so that they are consistent across processes
254
+ # for k in sorted(input_dict.keys()):
255
+ # names.append(k)
256
+ # values.append(input_dict[k])
257
+ # values = torch.stack(values, dim=0)
258
+ # dist.reduce(values, dst=0)
259
+ # if dist.get_rank() == 0 and average:
260
+ # # only main process gets accumulated, so only divide by
261
+ # # world_size in this case
262
+ # values /= world_size
263
+ # reduced_dict = {k: v for k, v in zip(names, values)}
264
+ # return reduced_dict
265
+
266
+
267
+ def reduce_dict(input_dict, average=True):
268
+ """
269
+ Reduce the values in the dictionary from all processes so that process with rank
270
+ 0 has the reduced results.
271
+ Args:
272
+ input_dict (dict): inputs to be reduced. (values not necessarily tensors).
273
+ average (bool): whether to do average or sum
274
+ Returns:
275
+ a dict with the same keys as input_dict, after reduction.
276
+ """
277
+
278
+ world_size = get_world_size()
279
+ if world_size < 2:
280
+ return input_dict
281
+
282
+ with torch.no_grad():
283
+
284
+ # Convert to CUDA Tensor for dist.reduce()
285
+ input_dict_cuda_vals = {}
286
+ for k, v in input_dict.items():
287
+ if type(v) == torch.Tensor:
288
+ input_dict_cuda_vals[k] = v.to('cuda')
289
+ else:
290
+ input_dict_cuda_vals[k] = torch.tensor(v, device='cuda')
291
+
292
+ names = []
293
+ values = []
294
+ for k, v in sorted(input_dict_cuda_vals.items()):
295
+ names.append(k)
296
+ values.append(v)
297
+ values = torch.stack(values, dim=0)
298
+ dist.reduce(values, dst=0) # reduce to gpu 0
299
+
300
+ if dist.get_rank() == 0 and average:
301
+ # only main process gets accumulated, so only divide by
302
+ # world_size in this case
303
+ values /= world_size
304
+ reduced_dict = {k: v for k, v in zip(names, values)}
305
+ return reduced_dict
captioning/utils/div_utils.py ADDED
@@ -0,0 +1,38 @@
1
+ from random import uniform
2
+ import numpy as np
3
+ from collections import OrderedDict, defaultdict
4
+ from itertools import tee
5
+ import time
6
+
7
+ # -----------------------------------------------
8
+ def find_ngrams(input_list, n):
9
+ return zip(*[input_list[i:] for i in range(n)])
10
+
11
+ def compute_div_n(caps,n=1):
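+ # Div-n: the number of distinct n-grams divided by the total token count, computed per image and then averaged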
12
+ aggr_div = []
13
+ for k in caps:
14
+ all_ngrams = set()
15
+ lenT = 0.
16
+ for c in caps[k]:
17
+ tkns = c.split()
18
+ lenT += len(tkns)
19
+ ng = find_ngrams(tkns, n)
20
+ all_ngrams.update(ng)
21
+ aggr_div.append(float(len(all_ngrams))/ (1e-6 + float(lenT)))
22
+ return np.array(aggr_div).mean(), np.array(aggr_div)
23
+
24
+ def compute_global_div_n(caps,n=1):
25
+ aggr_div = []
26
+ all_ngrams = set()
27
+ lenT = 0.
28
+ for k in caps:
29
+ for c in caps[k]:
30
+ tkns = c.split()
31
+ lenT += len(tkns)
32
+ ng = find_ngrams(tkns, n)
33
+ all_ngrams.update(ng)
34
+ if n == 1:
35
+ aggr_div.append(float(len(all_ngrams)))
36
+ else:
37
+ aggr_div.append(float(len(all_ngrams))/ (1e-6 + float(lenT)))
38
+ return aggr_div[0], np.repeat(np.array(aggr_div),len(caps))
captioning/utils/eval_multi.py ADDED
@@ -0,0 +1,218 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ import numpy as np
9
+ import json
10
+ from json import encoder
11
+ import random
12
+ import string
13
+ import time
14
+ import os
15
+ import sys
16
+ from . import misc as utils
17
+ from .eval_utils import getCOCO
18
+
19
+ from .div_utils import compute_div_n, compute_global_div_n
20
+
21
+ import sys
22
+ try:
23
+ sys.path.append("coco-caption")
24
+ annFile = 'coco-caption/annotations/captions_val2014.json'
25
+ from pycocotools.coco import COCO
26
+ from pycocoevalcap.eval import COCOEvalCap
27
+ from pycocoevalcap.eval_spice import COCOEvalCapSpice
28
+ from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
29
+ from pycocoevalcap.bleu.bleu import Bleu
30
+ sys.path.append("cider")
31
+ from pyciderevalcap.cider.cider import Cider
32
+ except:
33
+ print('Warning: requirements for eval_multi not satisfied')
34
+
35
+
36
+ def eval_allspice(dataset, preds_n, model_id, split):
37
+ coco = getCOCO(dataset)
38
+ valids = coco.getImgIds()
39
+
40
+ capsById = {}
41
+ for d in preds_n:
42
+ capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d]
43
+
44
+ # filter results to only those in MSCOCO validation set (will be about a third)
45
+ preds_filt_n = [p for p in preds_n if p['image_id'] in valids]
46
+ print('using %d/%d predictions_n' % (len(preds_filt_n), len(preds_n)))
47
+ cache_path_n = os.path.join('eval_results/', model_id + '_' + split + '_n.json')
48
+ json.dump(preds_filt_n, open(cache_path_n, 'w')) # serialize to temporary json file. Sigh, COCO API...
49
+
50
+ # Eval AllSPICE
51
+ cocoRes_n = coco.loadRes(cache_path_n)
52
+ cocoEvalAllSPICE = COCOEvalCapSpice(coco, cocoRes_n)
53
+ cocoEvalAllSPICE.params['image_id'] = cocoRes_n.getImgIds()
54
+ cocoEvalAllSPICE.evaluate()
55
+
56
+ out = {}
57
+ for metric, score in cocoEvalAllSPICE.eval.items():
58
+ out['All'+metric] = score
59
+
60
+ imgToEvalAllSPICE = cocoEvalAllSPICE.imgToEval
61
+ # collect SPICE_sub_score
62
+ for k in list(imgToEvalAllSPICE.values())[0]['SPICE'].keys():
63
+ if k != 'All':
64
+ out['AllSPICE_'+k] = np.array([v['SPICE'][k]['f'] for v in imgToEvalAllSPICE.values()])
65
+ out['AllSPICE_'+k] = (out['AllSPICE_'+k][out['AllSPICE_'+k]==out['AllSPICE_'+k]]).mean()
66
+ for p in preds_filt_n:
67
+ image_id, caption = p['image_id'], p['caption']
68
+ imgToEvalAllSPICE[image_id]['caption'] = capsById[image_id]
69
+ return {'overall': out, 'imgToEvalAllSPICE': imgToEvalAllSPICE}
70
+
71
+ def eval_oracle(dataset, preds_n, model_id, split):
72
+ cache_path = os.path.join('eval_results/', model_id + '_' + split + '_n.json')
73
+
74
+ coco = getCOCO(dataset)
75
+ valids = coco.getImgIds()
76
+
77
+ capsById = {}
78
+ for d in preds_n:
79
+ capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d]
80
+
81
+ sample_n = capsById[list(capsById.keys())[0]]
82
+ for i in range(len(capsById[list(capsById.keys())[0]])):
83
+ preds = [_[i] for _ in capsById.values()]
84
+
85
+ json.dump(preds, open(cache_path, 'w')) # serialize to temporary json file. Sigh, COCO API...
86
+
87
+ cocoRes = coco.loadRes(cache_path)
88
+ cocoEval = COCOEvalCap(coco, cocoRes)
89
+ cocoEval.params['image_id'] = cocoRes.getImgIds()
90
+ cocoEval.evaluate()
91
+
92
+ imgToEval = cocoEval.imgToEval
93
+ for img_id in capsById.keys():
94
+ tmp = imgToEval[img_id]
95
+ for k in tmp['SPICE'].keys():
96
+ if k != 'All':
97
+ tmp['SPICE_'+k] = tmp['SPICE'][k]['f']
98
+ if tmp['SPICE_'+k] != tmp['SPICE_'+k]: # nan
99
+ tmp['SPICE_'+k] = -100
100
+ tmp['SPICE'] = tmp['SPICE']['All']['f']
101
+ if tmp['SPICE'] != tmp['SPICE']: tmp['SPICE'] = -100
102
+ capsById[img_id][i]['scores'] = imgToEval[img_id]
103
+
104
+ out = {'overall': {}, 'ImgToEval': {}}
105
+ for img_id in capsById.keys():
106
+ out['ImgToEval'][img_id] = {}
107
+ for metric in capsById[img_id][0]['scores'].keys():
108
+ if metric == 'image_id': continue
109
+ out['ImgToEval'][img_id]['oracle_'+metric] = max([_['scores'][metric] for _ in capsById[img_id]])
110
+ out['ImgToEval'][img_id]['avg_'+metric] = sum([_['scores'][metric] for _ in capsById[img_id]]) / len(capsById[img_id])
111
+ out['ImgToEval'][img_id]['captions'] = capsById[img_id]
112
+ for metric in list(out['ImgToEval'].values())[0].keys():
113
+ if metric == 'captions':
114
+ continue
115
+ tmp = np.array([_[metric] for _ in out['ImgToEval'].values()])
116
+ tmp = tmp[tmp!=-100]
117
+ out['overall'][metric] = tmp.mean()
118
+
119
+ return out
120
+
121
+ def eval_div_stats(dataset, preds_n, model_id, split):
122
+ tokenizer = PTBTokenizer()
123
+
124
+ capsById = {}
125
+ for i, d in enumerate(preds_n):
126
+ d['id'] = i
127
+ capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d]
128
+
129
+ n_caps_perimg = len(capsById[list(capsById.keys())[0]])
130
+ print(n_caps_perimg)
131
+ _capsById = capsById # save the untokenized version
132
+ capsById = tokenizer.tokenize(capsById)
133
+
134
+ div_1, adiv_1 = compute_div_n(capsById,1)
135
+ div_2, adiv_2 = compute_div_n(capsById,2)
136
+
137
+ globdiv_1, _= compute_global_div_n(capsById,1)
138
+
139
+ print('Diversity Statistics are as follows: \n Div1: %.2f, Div2: %.2f, gDiv1: %d\n'%(div_1,div_2, globdiv_1))
140
+
141
+ # compute mbleu
142
+ scorer = Bleu(4)
143
+ all_scrs = []
144
+ scrperimg = np.zeros((n_caps_perimg, len(capsById)))
145
+
146
+ for i in range(n_caps_perimg):
147
+ tempRefsById = {}
148
+ candsById = {}
149
+ for k in capsById:
150
+ tempRefsById[k] = capsById[k][:i] + capsById[k][i+1:]
151
+ candsById[k] = [capsById[k][i]]
152
+
153
+ score, scores = scorer.compute_score(tempRefsById, candsById)
154
+ all_scrs.append(score)
155
+ scrperimg[i,:] = scores[1]
156
+
157
+ all_scrs = np.array(all_scrs)
158
+
159
+ out = {}
160
+ out['overall'] = {'Div1': div_1, 'Div2': div_2, 'gDiv1': globdiv_1}
161
+ for k, score in zip(range(4), all_scrs.mean(axis=0).tolist()):
162
+ out['overall'].update({'mBLeu_%d'%(k+1): score})
163
+ imgToEval = {}
164
+ for i,imgid in enumerate(capsById.keys()):
165
+ imgToEval[imgid] = {'mBleu_2' : scrperimg[:,i].mean()}
166
+ imgToEval[imgid]['individuals'] = []
167
+ for j, d in enumerate(_capsById[imgid]):
168
+ imgToEval[imgid]['individuals'].append(preds_n[d['id']])
169
+ imgToEval[imgid]['individuals'][-1]['mBleu_2'] = scrperimg[j,i]
170
+ out['ImgToEval'] = imgToEval
171
+
172
+ print('Mean mutual Bleu scores on this set are:\nmBLeu_1, mBLeu_2, mBLeu_3, mBLeu_4')
173
+ print(all_scrs.mean(axis=0))
174
+
175
+ return out
176
+
177
+ def eval_self_cider(dataset, preds_n, model_id, split):
178
+ cache_path = os.path.join('eval_results/', model_id + '_' + split + '_n.json')
179
+
180
+ coco = getCOCO(dataset)
181
+ valids = coco.getImgIds()
182
+
183
+ # Get Cider_scorer
184
+ Cider_scorer = Cider(df='corpus')
185
+
186
+ tokenizer = PTBTokenizer()
187
+ gts = {}
188
+ for imgId in valids:
189
+ gts[imgId] = coco.imgToAnns[imgId]
190
+ gts = tokenizer.tokenize(gts)
191
+
192
+ for imgId in valids:
193
+ Cider_scorer.cider_scorer += (None, gts[imgId])
194
+ Cider_scorer.cider_scorer.compute_doc_freq()
195
+ Cider_scorer.cider_scorer.ref_len = np.log(float(len(Cider_scorer.cider_scorer.crefs)))
196
+
197
+ # Prepare captions
198
+ capsById = {}
199
+ for d in preds_n:
200
+ capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d]
201
+
202
+ capsById = tokenizer.tokenize(capsById)
203
+ imgIds = list(capsById.keys())
204
+ scores = Cider_scorer.my_self_cider([capsById[_] for _ in imgIds])
205
+
206
+ def get_div(eigvals):
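+ # spectral diversity of the pairwise CIDEr matrix: near 0 when all captions are identical (rank-1 matrix), near 1 when maximally diverse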
207
+ eigvals = np.clip(eigvals, 0, None)
208
+ return -np.log(np.sqrt(eigvals[-1]) / (np.sqrt(eigvals).sum())) / np.log(len(eigvals))
209
+ sc_scores = [get_div(np.linalg.eigvalsh(_/10)) for _ in scores]
210
+ score = np.mean(np.array(sc_scores))
211
+
212
+ imgToEval = {}
213
+ for i, image_id in enumerate(imgIds):
214
+ imgToEval[image_id] = {'self_cider': sc_scores[i], 'self_cider_mat': scores[i].tolist()}
215
+ return {'overall': {'self_cider': score}, 'imgToEval': imgToEval}
216
+
217
+
218
+ return score
captioning/utils/eval_utils.py ADDED
@@ -0,0 +1,281 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ import numpy as np
10
+ import json
11
+ from json import encoder
12
+ import random
13
+ import string
14
+ import time
15
+ import os
16
+ import sys
17
+ from . import misc as utils
18
+
19
+ # load coco-caption if available
20
+ try:
21
+ sys.path.append("coco-caption")
22
+ from pycocotools.coco import COCO
23
+ from pycocoevalcap.eval import COCOEvalCap
24
+ except:
25
+ print('Warning: coco-caption not available')
26
+
27
+ bad_endings = ['a','an','the','in','for','at','of','with','before','after','on','upon','near','to','is','are','am']
28
+ bad_endings += ['the']
29
+
30
+
31
+ def count_bad(sen):
32
+ sen = sen.split(' ')
33
+ if sen[-1] in bad_endings:
34
+ return 1
35
+ else:
36
+ return 0
37
+
38
+
39
+ def getCOCO(dataset):
40
+ if 'coco' in dataset:
41
+ annFile = 'coco-caption/annotations/captions_val2014.json'
42
+ elif 'flickr30k' in dataset or 'f30k' in dataset:
43
+ annFile = 'data/f30k_captions4eval.json'
44
+ return COCO(annFile)
45
+
46
+
47
+ def language_eval(dataset, preds, preds_n, eval_kwargs, split):
48
+ model_id = eval_kwargs['id']
49
+ eval_oracle = eval_kwargs.get('eval_oracle', 0)
50
+
51
+ # create output dictionary
52
+ out = {}
53
+
54
+ if len(preds_n) > 0:
55
+ # vocab size and novel sentences
56
+ if 'coco' in dataset:
57
+ dataset_file = 'data/dataset_coco.json'
58
+ elif 'flickr30k' in dataset or 'f30k' in dataset:
59
+ dataset_file = 'data/dataset_flickr30k.json'
60
+ training_sentences = set([' '.join(__['tokens']) for _ in json.load(open(dataset_file))['images'] if not _['split'] in ['val', 'test'] for __ in _['sentences']])
61
+ generated_sentences = set([_['caption'] for _ in preds_n])
62
+ novels = generated_sentences - training_sentences
63
+ out['novel_sentences'] = float(len(novels)) / len(preds_n)
64
+ tmp = [_.split() for _ in generated_sentences]
65
+ words = []
66
+ for _ in tmp:
67
+ words += _
68
+ out['vocab_size'] = len(set(words))
69
+
70
+ # encoder.FLOAT_REPR = lambda o: format(o, '.3f')
71
+
72
+ cache_path = os.path.join('eval_results/', '.cache_'+ model_id + '_' + split + '.json')
73
+
74
+ coco = getCOCO(dataset)
75
+ valids = coco.getImgIds()
76
+
77
+ # filter results to only those in MSCOCO validation set
78
+ preds_filt = [p for p in preds if p['image_id'] in valids]
79
+ mean_perplexity = sum([_['perplexity'] for _ in preds_filt]) / len(preds_filt)
80
+ mean_entropy = sum([_['entropy'] for _ in preds_filt]) / len(preds_filt)
81
+ print('using %d/%d predictions' % (len(preds_filt), len(preds)))
82
+ json.dump(preds_filt, open(cache_path, 'w')) # serialize to temporary json file. Sigh, COCO API...
83
+
84
+ cocoRes = coco.loadRes(cache_path)
85
+ cocoEval = COCOEvalCap(coco, cocoRes)
86
+ cocoEval.params['image_id'] = cocoRes.getImgIds()
87
+ cocoEval.evaluate()
88
+
89
+ for metric, score in cocoEval.eval.items():
90
+ out[metric] = score
91
+ # Add mean perplexity
92
+ out['perplexity'] = mean_perplexity
93
+ out['entropy'] = mean_entropy
94
+
95
+ imgToEval = cocoEval.imgToEval
96
+ for k in list(imgToEval.values())[0]['SPICE'].keys():
97
+ if k != 'All':
98
+ out['SPICE_'+k] = np.array([v['SPICE'][k]['f'] for v in imgToEval.values()])
99
+ out['SPICE_'+k] = (out['SPICE_'+k][out['SPICE_'+k]==out['SPICE_'+k]]).mean()
100
+ for p in preds_filt:
101
+ image_id, caption = p['image_id'], p['caption']
102
+ imgToEval[image_id]['caption'] = caption
103
+
104
+ if len(preds_n) > 0:
105
+ from . import eval_multi
106
+ cache_path_n = os.path.join('eval_results/', '.cache_'+ model_id + '_' + split + '_n.json')
107
+ allspice = eval_multi.eval_allspice(dataset, preds_n, model_id, split)
108
+ out.update(allspice['overall'])
109
+ div_stats = eval_multi.eval_div_stats(dataset, preds_n, model_id, split)
110
+ out.update(div_stats['overall'])
111
+ if eval_oracle:
112
+ oracle = eval_multi.eval_oracle(dataset, preds_n, model_id, split)
113
+ out.update(oracle['overall'])
114
+ else:
115
+ oracle = None
116
+ self_cider = eval_multi.eval_self_cider(dataset, preds_n, model_id, split)
117
+ out.update(self_cider['overall'])
118
+ with open(cache_path_n, 'w') as outfile:
119
+ json.dump({'allspice': allspice, 'div_stats': div_stats, 'oracle': oracle, 'self_cider': self_cider}, outfile)
120
+
121
+ out['bad_count_rate'] = sum([count_bad(_['caption']) for _ in preds_filt]) / float(len(preds_filt))
122
+ outfile_path = os.path.join('eval_results/', model_id + '_' + split + '.json')
123
+ with open(outfile_path, 'w') as outfile:
124
+ json.dump({'overall': out, 'imgToEval': imgToEval}, outfile)
125
+
126
+ return out
127
+
128
+ def eval_split(model, crit, loader, eval_kwargs={}):
129
+ verbose = eval_kwargs.get('verbose', True)
130
+ verbose_beam = eval_kwargs.get('verbose_beam', 0)
131
+ verbose_loss = eval_kwargs.get('verbose_loss', 1)
132
+ num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1))
133
+ split = eval_kwargs.get('split', 'val')
134
+ lang_eval = eval_kwargs.get('language_eval', 0)
135
+ dataset = eval_kwargs.get('dataset', 'coco')
136
+ beam_size = eval_kwargs.get('beam_size', 1)
137
+ sample_n = eval_kwargs.get('sample_n', 1)
138
+ remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)
139
+ os.environ["REMOVE_BAD_ENDINGS"] = str(remove_bad_endings) # Use this nasty way to make other code clean since it's a global configuration
140
+ device = eval_kwargs.get('device', 'cuda')
141
+
142
+ # Make sure in the evaluation mode
143
+ model.eval()
144
+
145
+ loader.reset_iterator(split)
146
+
147
+ n = 0
148
+ loss = 0
149
+ loss_sum = 0
150
+ loss_evals = 1e-8
151
+ predictions = []
152
+ n_predictions = [] # when sample_n > 1
153
+ while True:
154
+ data = loader.get_batch(split)
155
+ n = n + len(data['infos'])
156
+
157
+ tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
158
+ tmp = [_.to(device) if _ is not None else _ for _ in tmp]
159
+ fc_feats, att_feats, labels, masks, att_masks = tmp
160
+ if labels is not None and verbose_loss:
161
+ # forward the model to get loss
162
+ with torch.no_grad():
163
+ loss = crit(model(fc_feats, att_feats, labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:]).item()
164
+ loss_sum = loss_sum + loss
165
+ loss_evals = loss_evals + 1
166
+
167
+ # forward the model to also get generated samples for each image
168
+ with torch.no_grad():
169
+ tmp_eval_kwargs = eval_kwargs.copy()
170
+ tmp_eval_kwargs.update({'sample_n': 1})
171
+ seq, seq_logprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
172
+ seq = seq.data
173
+ entropy = - (F.softmax(seq_logprobs, dim=2) * seq_logprobs).sum(2).sum(1) / ((seq>0).to(seq_logprobs).sum(1)+1)
174
+ perplexity = - seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze(2).sum(1) / ((seq>0).to(seq_logprobs).sum(1)+1)
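+ # both are averaged over the generated tokens: entropy of the predicted word distribution and negative log-probability of the chosen words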
175
+
176
+ # Print beam search
177
+ if beam_size > 1 and verbose_beam:
178
+ for i in range(fc_feats.shape[0]):
179
+ print('\n'.join([utils.decode_sequence(model.vocab, _['seq'].unsqueeze(0))[0] for _ in model.done_beams[i]]))
180
+ print('--' * 10)
181
+ sents = utils.decode_sequence(model.vocab, seq)
182
+
183
+ for k, sent in enumerate(sents):
184
+ entry = {'image_id': data['infos'][k]['id'], 'caption': sent, 'perplexity': perplexity[k].item(), 'entropy': entropy[k].item()}
185
+ if eval_kwargs.get('dump_path', 0) == 1:
186
+ entry['file_name'] = data['infos'][k]['file_path']
187
+ predictions.append(entry)
188
+ if eval_kwargs.get('dump_images', 0) == 1:
189
+ # dump the raw image to vis/ folder
190
+ cmd = 'cp "' + os.path.join(eval_kwargs['image_root'], data['infos'][k]['file_path']) + '" vis/imgs/img' + str(len(predictions)) + '.jpg' # bit gross
191
+ print(cmd)
192
+ os.system(cmd)
193
+
194
+ if verbose:
195
+ print('image %s: %s' %(entry['image_id'], entry['caption']))
196
+
197
+ if sample_n > 1:
198
+ eval_split_n(model, n_predictions, [fc_feats, att_feats, att_masks, data], eval_kwargs)
199
+
200
+ # ix0 = data['bounds']['it_pos_now']
201
+ ix1 = data['bounds']['it_max']
202
+ if num_images != -1:
203
+ ix1 = min(ix1, num_images)
204
+ else:
205
+ num_images = ix1
206
+ for i in range(n - ix1):
207
+ predictions.pop()
208
+
209
+ if verbose:
210
+ print('evaluating validation performance... %d/%d (%f)' %(n, ix1, loss))
211
+
212
+ if num_images >= 0 and n >= num_images:
213
+ break
214
+
215
+ lang_stats = None
216
+ if len(n_predictions) > 0 and 'perplexity' in n_predictions[0]:
217
+ n_predictions = sorted(n_predictions, key=lambda x: x['perplexity'])
218
+ if not os.path.isdir('eval_results'):
219
+ os.mkdir('eval_results')
220
+ torch.save((predictions, n_predictions), os.path.join('eval_results/', '.saved_pred_'+ eval_kwargs['id'] + '_' + split + '.pth'))
221
+ if lang_eval == 1:
222
+ lang_stats = language_eval(dataset, predictions, n_predictions, eval_kwargs, split)
223
+
224
+ # Switch back to training mode
225
+ model.train()
226
+ return loss_sum/loss_evals, predictions, lang_stats
227
+
228
+
229
+ # Only run when sample_n > 0
230
+ def eval_split_n(model, n_predictions, input_data, eval_kwargs={}):
231
+ verbose = eval_kwargs.get('verbose', True)
232
+ beam_size = eval_kwargs.get('beam_size', 1)
233
+ sample_n = eval_kwargs.get('sample_n', 1)
234
+ sample_n_method = eval_kwargs.get('sample_n_method', 'sample')
235
+
236
+ fc_feats, att_feats, att_masks, data = input_data
237
+
238
+ tmp_eval_kwargs = eval_kwargs.copy()
239
+ if sample_n_method == 'bs':
240
+ # case 1 sample_n == beam size
241
+ tmp_eval_kwargs.update({'sample_n': 1, 'beam_size': sample_n, 'group_size': 1}) # randomness from softmax
242
+ with torch.no_grad():
243
+ model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
244
+ for k in range(fc_feats.shape[0]):
245
+ _sents = utils.decode_sequence(model.vocab, torch.stack([model.done_beams[k][_]['seq'] for _ in range(sample_n)]))
246
+ for sent in _sents:
247
+ entry = {'image_id': data['infos'][k]['id'], 'caption': sent}
248
+ n_predictions.append(entry)
249
+ # case 2 sample / gumbel / topk sampling/ nucleus sampling
250
+ elif sample_n_method == 'sample' or \
251
+ sample_n_method == 'gumbel' or \
252
+ sample_n_method.startswith('top'):
253
+ tmp_eval_kwargs.update({'sample_n': sample_n, 'sample_method': sample_n_method, 'beam_size': 1}) # randomness from sample
254
+ with torch.no_grad():
255
+ _seq, _sampleLogprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
256
+ _sents = utils.decode_sequence(model.vocab, _seq)
257
+ _perplexity = - _sampleLogprobs.gather(2, _seq.unsqueeze(2)).squeeze(2).sum(1) / ((_seq>0).to(_sampleLogprobs).sum(1)+1)
258
+ for k, sent in enumerate(_sents):
259
+ entry = {'image_id': data['infos'][k // sample_n]['id'], 'caption': sent, 'perplexity': _perplexity[k].item()}
260
+ n_predictions.append(entry)
261
+ elif sample_n_method == 'dbs':
262
+ # Use diverse beam search
263
+ tmp_eval_kwargs.update({'beam_size': sample_n * beam_size, 'group_size': sample_n}) # randomness from softmax
264
+ with torch.no_grad():
265
+ model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
266
+ for k in range(fc_feats.shape[0]):
267
+ _sents = utils.decode_sequence(model.vocab, torch.stack([model.done_beams[k][_]['seq'] for _ in range(0, sample_n*beam_size, beam_size)]))
268
+ for sent in _sents:
269
+ entry = {'image_id': data['infos'][k]['id'], 'caption': sent}
270
+ n_predictions.append(entry)
271
+ else:
272
+ tmp_eval_kwargs.update({'sample_method': sample_n_method[1:], 'group_size': sample_n, 'beam_size':1}) # randomness from softmax
273
+ with torch.no_grad():
274
+ _seq, _sampleLogprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
275
+ _sents = utils.decode_sequence(model.vocab, _seq)
276
+ for k, sent in enumerate(_sents):
277
+ entry = {'image_id': data['infos'][k // sample_n]['id'], 'caption': sent}
278
+ n_predictions.append(entry)
279
+ if verbose:
280
+ for entry in sorted(n_predictions[-fc_feats.shape[0] * sample_n:], key=lambda x: x['image_id']):
281
+ print('image %s: %s' %(entry['image_id'], entry['caption']))
captioning/utils/misc.py ADDED
@@ -0,0 +1,251 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import collections
6
+ import torch
7
+ import torch.nn as nn
8
+ import numpy as np
9
+ import torch.optim as optim
10
+ import os
11
+
12
+ import torch.nn.functional as F
13
+
14
+ import six
15
+ from six.moves import cPickle
16
+
17
+ bad_endings = ['with','in','on','of','a','at','to','for','an','this','his','her','that']
18
+ bad_endings += ['the']
19
+
20
+
21
+ def pickle_load(f):
22
+ """ Load a pickle.
23
+ Parameters
24
+ ----------
25
+ f: file-like object
26
+ """
27
+ if six.PY3:
28
+ return cPickle.load(f, encoding='latin-1')
29
+ else:
30
+ return cPickle.load(f)
31
+
32
+
33
+ def pickle_dump(obj, f):
34
+ """ Dump a pickle.
35
+ Parameters
36
+ ----------
37
+ obj: pickled object
38
+ f: file-like object
39
+ """
40
+ if six.PY3:
41
+ return cPickle.dump(obj, f, protocol=2)
42
+ else:
43
+ return cPickle.dump(obj, f)
44
+
45
+
46
+ # modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/comm.py
47
+ def serialize_to_tensor(data):
48
+ device = torch.device("cpu")
49
+
50
+ buffer = cPickle.dumps(data)
51
+ storage = torch.ByteStorage.from_buffer(buffer)
52
+ tensor = torch.ByteTensor(storage).to(device=device)
53
+ return tensor
54
+
55
+
56
+ def deserialize(tensor):
57
+ buffer = tensor.cpu().numpy().tobytes()
58
+ return cPickle.loads(buffer)
59
+
60
+
61
+ # Input: seq, N*D numpy array, with element 0 .. vocab_size. 0 is END token.
62
+ def decode_sequence(ix_to_word, seq):
63
+ # N, D = seq.size()
64
+ N, D = seq.shape
65
+ out = []
66
+ for i in range(N):
67
+ txt = ''
68
+ for j in range(D):
69
+ ix = seq[i,j]
70
+ if ix > 0 :
71
+ if j >= 1:
72
+ txt = txt + ' '
73
+ txt = txt + ix_to_word[str(ix.item())]
74
+ else:
75
+ break
76
+ if int(os.getenv('REMOVE_BAD_ENDINGS', '0')):
77
+ flag = 0
78
+ words = txt.split(' ')
79
+ for j in range(len(words)):
80
+ if words[-j-1] not in bad_endings:
81
+ flag = -j
82
+ break
83
+ txt = ' '.join(words[0:len(words)+flag])
84
+ out.append(txt.replace('@@ ', ''))
85
+ return out
86
+
87
+
88
+ def save_checkpoint(opt, model, infos, optimizer, histories=None, append=''):
89
+ if len(append) > 0:
90
+ append = '-' + append
91
+ # if checkpoint_path doesn't exist
92
+ if not os.path.isdir(opt.checkpoint_path):
93
+ os.makedirs(opt.checkpoint_path)
94
+ checkpoint_path = os.path.join(opt.checkpoint_path, 'model%s.pth' %(append))
95
+ torch.save(model.state_dict(), checkpoint_path)
96
+ print("model saved to {}".format(checkpoint_path))
97
+ optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer%s.pth' %(append))
98
+ torch.save(optimizer.state_dict(), optimizer_path)
99
+ with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'%s.pkl' %(append)), 'wb') as f:
100
+ pickle_dump(infos, f)
101
+ if histories:
102
+ with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'%s.pkl' %(append)), 'wb') as f:
103
+ pickle_dump(histories, f)
104
+
105
+
106
+ def set_lr(optimizer, lr):
107
+ for group in optimizer.param_groups:
108
+ group['lr'] = lr
109
+
110
+ def get_lr(optimizer):
111
+ for group in optimizer.param_groups:
112
+ return group['lr']
113
+
114
+
115
+ def build_optimizer(params, opt):
116
+ if opt.optim == 'rmsprop':
117
+ return optim.RMSprop(params, opt.learning_rate, opt.optim_alpha, opt.optim_epsilon, weight_decay=opt.weight_decay)
118
+ elif opt.optim == 'adagrad':
119
+ return optim.Adagrad(params, opt.learning_rate, weight_decay=opt.weight_decay)
120
+ elif opt.optim == 'sgd':
121
+ return optim.SGD(params, opt.learning_rate, weight_decay=opt.weight_decay)
122
+ elif opt.optim == 'sgdm':
123
+ return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay)
124
+ elif opt.optim == 'sgdmom':
125
+ return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay, nesterov=True)
126
+ elif opt.optim == 'adam':
127
+ return optim.Adam(params, opt.learning_rate, (opt.optim_alpha, opt.optim_beta), opt.optim_epsilon, weight_decay=opt.weight_decay)
128
+ elif opt.optim == 'adamw':
129
+ return optim.AdamW(params, opt.learning_rate, (opt.optim_alpha, opt.optim_beta), opt.optim_epsilon, weight_decay=opt.weight_decay)
130
+ else:
131
+ raise Exception("bad option opt.optim: {}".format(opt.optim))
132
+
133
+
134
+ def penalty_builder(penalty_config):
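+ # penalty_config is '<type>_<alpha>', e.g. 'wu_0.9' or 'avg_1'; an empty string disables length re-ranking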
135
+ if penalty_config == '':
136
+ return lambda x,y: y
137
+ pen_type, alpha = penalty_config.split('_')
138
+ alpha = float(alpha)
139
+ if pen_type == 'wu':
140
+ return lambda x,y: length_wu(x,y,alpha)
141
+ if pen_type == 'avg':
142
+ return lambda x,y: length_average(x,y,alpha)
143
+
144
+ def length_wu(length, logprobs, alpha=0.):
145
+ """
146
+ NMT length re-ranking score from
147
+ "Google's Neural Machine Translation System" :cite:`wu2016google`.
148
+ """
149
+
150
+ modifier = (((5 + length) ** alpha) /
151
+ ((5 + 1) ** alpha))
152
+ return (logprobs / modifier)
153
+
154
+ def length_average(length, logprobs, alpha=0.):
155
+ """
156
+ Returns the average probability of tokens in a sequence.
157
+ """
158
+ return logprobs / length
159
+
160
+
161
+ class NoamOpt(object):
162
+ "Optim wrapper that implements rate."
163
+ def __init__(self, model_size, factor, warmup, optimizer):
164
+ self.optimizer = optimizer
165
+ self._step = 0
166
+ self.warmup = warmup
167
+ self.factor = factor
168
+ self.model_size = model_size
169
+ self._rate = 0
170
+
171
+ def step(self):
172
+ "Update parameters and rate"
173
+ self._step += 1
174
+ rate = self.rate()
175
+ for p in self.optimizer.param_groups:
176
+ p['lr'] = rate
177
+ self._rate = rate
178
+ self.optimizer.step()
179
+
180
+ def rate(self, step = None):
181
+ "Implement `lrate` above"
182
+ if step is None:
183
+ step = self._step
184
+ return self.factor * \
185
+ (self.model_size ** (-0.5) *
186
+ min(step ** (-0.5), step * self.warmup ** (-1.5)))
187
+
188
+ def __getattr__(self, name):
189
+ return getattr(self.optimizer, name)
190
+
191
+ def state_dict(self):
192
+ state_dict = self.optimizer.state_dict()
193
+ state_dict['_step'] = self._step
194
+ return state_dict
195
+
196
+ def load_state_dict(self, state_dict):
197
+ if '_step' in state_dict:
198
+ self._step = state_dict['_step']
199
+ del state_dict['_step']
200
+ self.optimizer.load_state_dict(state_dict)
201
+
202
+ class ReduceLROnPlateau(object):
203
+ "Optim wrapper that implements rate."
204
+ def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08):
205
+ self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode, factor, patience, verbose, threshold, threshold_mode, cooldown, min_lr, eps)
206
+ self.optimizer = optimizer
207
+ self.current_lr = get_lr(optimizer)
208
+
209
+ def step(self):
210
+ "Update parameters and rate"
211
+ self.optimizer.step()
212
+
213
+ def scheduler_step(self, val):
214
+ self.scheduler.step(val)
215
+ self.current_lr = get_lr(self.optimizer)
216
+
217
+ def state_dict(self):
218
+ return {'current_lr':self.current_lr,
219
+ 'scheduler_state_dict': self.scheduler.state_dict(),
220
+ 'optimizer_state_dict': self.optimizer.state_dict()}
221
+
222
+ def load_state_dict(self, state_dict):
223
+ if 'current_lr' not in state_dict:
224
+ # it's normal optimizer
225
+ self.optimizer.load_state_dict(state_dict)
226
+ set_lr(self.optimizer, self.current_lr) # use the lr from the option
227
+ else:
228
+ # it's a scheduler
229
+ self.current_lr = state_dict['current_lr']
230
+ self.scheduler.load_state_dict(state_dict['scheduler_state_dict'])
231
+ self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
232
+ # current_lr is actually useless in this case
233
+
234
+ def rate(self, step = None):
235
+ "Implement `lrate` above"
236
+ if step is None:
237
+ step = self._step
238
+ return self.factor * \
239
+ (self.model_size ** (-0.5) *
240
+ min(step ** (-0.5), step * self.warmup ** (-1.5)))
241
+
242
+ def __getattr__(self, name):
243
+ return getattr(self.optimizer, name)
244
+
245
+ def get_std_opt(model, optim_func='adam', factor=1, warmup=2000):
246
+ # return NoamOpt(model.tgt_embed[0].d_model, 2, 4000,
247
+ # torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
248
+ optim_func = dict(adam=torch.optim.Adam,
249
+ adamw=torch.optim.AdamW)[optim_func]
250
+ return NoamOpt(model.d_model, factor, warmup,
251
+ optim_func(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
captioning/utils/opts.py ADDED
@@ -0,0 +1,412 @@
1
+ from __future__ import print_function
2
+ import argparse
3
+
4
+
5
+ def if_use_feat(caption_model):
6
+ # Decide if load attention feature according to caption model
7
+ if caption_model in ['show_tell', 'all_img', 'fc', 'newfc']:
8
+ use_att, use_fc = False, True
9
+ elif caption_model == 'language_model':
10
+ use_att, use_fc = False, False
11
+ elif caption_model in ['updown', 'topdown']:
12
+ use_fc, use_att = True, True
13
+ else:
14
+ use_att, use_fc = True, False
15
+ return use_fc, use_att
16
+
17
+ import pprint
18
+ class Config(object):
19
+ def __init__(self, **kwargs):
20
+ """Configuration Class: set kwargs as class attributes with setattr"""
21
+ for k, v in kwargs.items():
22
+ setattr(self, k, v)
23
+
24
+ @property
25
+ def config_str(self):
26
+ return pprint.pformat(self.__dict__)
27
+
28
+ def __repr__(self):
29
+ """Pretty-print configurations in alphabetical order"""
30
+ config_str = 'Configurations\n'
31
+ config_str += self.config_str
32
+ return config_str
33
+
34
+
35
+ def parse_opt(parse=True, **optional_kwargs):
36
+ parser = argparse.ArgumentParser()
37
+ # Data input settings
38
+ parser.add_argument('--input_json', type=str, default='data/coco.json',
39
+ help='path to the json file containing additional info and vocab')
40
+ parser.add_argument('--input_fc_dir', type=str, default='data/cocotalk_fc',
41
+ help='path to the directory containing the preprocessed fc feats')
42
+ parser.add_argument('--input_att_dir', type=str, default='data/cocotalk_att',
43
+ help='path to the directory containing the preprocessed att feats')
44
+ parser.add_argument('--input_box_dir', type=str, default='data/cocotalk_box',
45
+ help='path to the directory containing the boxes of att feats')
46
+ parser.add_argument('--input_label_h5', type=str, default='data/coco_label.h5',
47
+ help='path to the h5file containing the preprocessed dataset')
48
+ parser.add_argument('--data_in_memory', action='store_true',
49
+ help='True if we want to save the features in memory')
50
+ parser.add_argument('--start_from', type=str, default=None,
51
+ help="""continue training from saved model at this path. Path must contain files saved by previous training process:
52
+ 'infos.pkl' : configuration;
53
+ 'model.pth' : weights
54
+ """)
55
+ parser.add_argument('--cached_tokens', type=str, default='coco-train-idxs',
56
+ help='Cached token file for calculating cider score during self critical training.')
57
+
58
+ # Model settings
59
+ parser.add_argument('--caption_model', type=str, default="show_tell",
60
+ help='show_tell, show_attend_tell, all_img, fc, att2in, att2in2, att2all2, adaatt, adaattmo, updown, stackatt, denseatt, transformer')
61
+ parser.add_argument('--rnn_size', type=int, default=512,
62
+ help='size of the rnn in number of hidden nodes in each layer')
63
+ parser.add_argument('--num_layers', type=int, default=1,
64
+ help='number of layers in the RNN')
65
+ parser.add_argument('--rnn_type', type=str, default='lstm',
66
+ help='rnn, gru, or lstm')
67
+ parser.add_argument('--input_encoding_size', type=int, default=512,
68
+ help='the encoding size of each token in the vocabulary, and the image.')
69
+ parser.add_argument('--att_hid_size', type=int, default=512,
70
+ help='the hidden size of the attention MLP; only useful in show_attend_tell; 0 if not using hidden layer')
71
+ parser.add_argument('--fc_feat_size', type=int, default=2048,
72
+ help='2048 for resnet, 4096 for vgg')
73
+ parser.add_argument('--att_feat_size', type=int, default=2048,
74
+ help='2048 for resnet, 512 for vgg')
75
+ parser.add_argument('--logit_layers', type=int, default=1,
76
+ help='number of layers in the output logit layer')
77
+
78
+
79
+ parser.add_argument('--use_bn', type=int, default=0,
80
+ help='If 1, then do batch_normalization first in att_embed, if 2 then do bn both in the beginning and the end of att_embed')
81
+
82
+ # feature manipulation
83
+ parser.add_argument('--norm_att_feat', type=int, default=0,
84
+ help='If normalize attention features')
85
+ parser.add_argument('--use_box', type=int, default=0,
86
+ help='If use box features')
87
+ parser.add_argument('--norm_box_feat', type=int, default=0,
88
+ help='If use box, do we normalize box feature')
89
+
90
+ # Optimization: General
91
+ parser.add_argument('--max_epochs', type=int, default=-1,
92
+ help='number of epochs')
93
+ parser.add_argument('--batch_size', type=int, default=16,
94
+ help='minibatch size')
95
+ parser.add_argument('--grad_clip_mode', type=str, default='value',
96
+ help='value or norm')
97
+ parser.add_argument('--grad_clip_value', type=float, default=0.1,
98
+ help='clip gradients at this value/max_norm, 0 means no clipping')
99
+ parser.add_argument('--drop_prob_lm', type=float, default=0.5,
100
+ help='strength of dropout in the Language Model RNN')
101
+ parser.add_argument('--self_critical_after', type=int, default=-1,
102
+ help='After what epoch do we start self-critical training? (-1 = disable; never, 0 = from the start)')
103
+ parser.add_argument('--seq_per_img', type=int, default=5,
104
+ help='number of captions to sample for each image during training. Done for efficiency since CNN forward pass is expensive. E.g. coco has 5 sents/image')
105
+
106
+ parser.add_argument('--verbose', type=int, default=0)
107
+
108
+ # Sample related
109
+ add_eval_sample_opts(parser)
110
+
111
+ #Optimization: for the Language Model
112
+ parser.add_argument('--optim', type=str, default='adam',
113
+ help='what update to use? rmsprop|sgd|sgdmom|adagrad|adam|adamw')
114
+ parser.add_argument('--learning_rate', type=float, default=4e-4,
115
+ help='learning rate')
116
+ parser.add_argument('--learning_rate_decay_start', type=int, default=-1,
117
+ help='at what epoch to start decaying the learning rate? (-1 = never)')
118
+ parser.add_argument('--learning_rate_decay_every', type=int, default=3,
119
+ help='every how many epochs thereafter to drop the LR')
120
+ parser.add_argument('--learning_rate_decay_rate', type=float, default=0.8,
121
+ help='multiplicative factor applied to the LR at each decay step')
122
+ parser.add_argument('--optim_alpha', type=float, default=0.9,
123
+ help='alpha for adam')
124
+ parser.add_argument('--optim_beta', type=float, default=0.999,
125
+ help='beta used for adam')
126
+ parser.add_argument('--optim_epsilon', type=float, default=1e-8,
127
+ help='epsilon that goes into denominator for smoothing')
128
+ parser.add_argument('--weight_decay', type=float, default=0,
129
+ help='weight_decay')
130
+ # Transformer
131
+ parser.add_argument('--label_smoothing', type=float, default=0,
132
+ help='')
133
+ parser.add_argument('--noamopt', action='store_true',
134
+ help='')
135
+ parser.add_argument('--noamopt_warmup', type=int, default=2000,
136
+ help='')
137
+ parser.add_argument('--noamopt_factor', type=float, default=1,
138
+ help='')
139
+ parser.add_argument('--reduce_on_plateau', action='store_true',
140
+ help='')
141
+ parser.add_argument('--reduce_on_plateau_factor', type=float, default=0.5,
142
+ help='')
143
+ parser.add_argument('--reduce_on_plateau_patience', type=int, default=3,
144
+ help='')
145
+ parser.add_argument('--cached_transformer', action='store_true',
146
+ help='')
147
+
148
+
149
+ parser.add_argument('--use_warmup', action='store_true',
150
+ help='warm up the learning rate?')
151
+
152
+ parser.add_argument('--scheduled_sampling_start', type=int, default=-1,
153
+ help='at what iteration to start decay gt probability')
154
+ parser.add_argument('--scheduled_sampling_increase_every', type=int, default=5,
155
+ help='every how many epochs thereafter to increase the scheduled sampling probability')
156
+ parser.add_argument('--scheduled_sampling_increase_prob', type=float, default=0.05,
157
+ help='How much to update the prob')
158
+ parser.add_argument('--scheduled_sampling_max_prob', type=float, default=0.25,
159
+ help='Maximum scheduled sampling prob.')
160
+
161
+
162
+ # Evaluation/Checkpointing
163
+ parser.add_argument('--val_images_use', type=int, default=3200,
164
+ help='how many images to use when periodically evaluating the validation loss? (-1 = all)')
165
+ parser.add_argument('--save_checkpoint_every', type=int, default=2500,
166
+ help='how often to save a model checkpoint (in iterations)?')
167
+ parser.add_argument('--save_every_epoch', action='store_true',
168
+ help='Save checkpoint every epoch, will overwrite save_checkpoint_every')
169
+ parser.add_argument('--save_history_ckpt', type=int, default=0,
170
+ help='If save checkpoints at every save point')
171
+ parser.add_argument('--checkpoint_path', type=str, default=None,
172
+ help='directory to store checkpointed models')
173
+ parser.add_argument('--language_eval', type=int, default=0,
174
+ help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.')
175
+ parser.add_argument('--losses_log_every', type=int, default=25,
176
+ help='How often do we snapshot losses, for inclusion in the progress dump? (0 = disable)')
177
+ parser.add_argument('--load_best_score', type=int, default=1,
178
+ help='Do we load previous best score when resuming training.')
179
+
180
+ # misc
181
+ parser.add_argument('--id', type=str, default='',
182
+ help='an id identifying this run/job. used in cross-val and appended when writing progress files')
183
+ parser.add_argument('--train_only', type=int, default=0,
184
+ help='if 1 then train on the 80k train split only, else use the 110k train+restval split')
185
+
186
+
187
+ # Reward
188
+ parser.add_argument('--cider_reward_weight', type=float, default=1,
189
+ help='The reward weight from cider')
190
+ parser.add_argument('--bleu_reward_weight', type=float, default=0,
191
+ help='The reward weight from bleu4')
192
+
193
+ # Reward
194
+ parser.add_argument('--clipscore_reward_weight', type=float, default=1,
195
+ help='The reward weight from clipscore')
196
+ parser.add_argument('--use_clipscore', type=float, default=0,
197
+ help='Use CLIPScore')
198
+ parser.add_argument('--clipscore_mode', type=str, default='clip_s',
199
+ help='Which CLIPScore to use: clip_s|refclip_s')
200
+
201
+
202
+ # Structure_loss
203
+ parser.add_argument('--structure_loss_weight', type=float, default=1,
204
+ help='')
205
+ parser.add_argument('--structure_after', type=int, default=-1,
206
+ help='After what epoch do we start structure loss training? (-1 = disable)')
207
+ parser.add_argument('--structure_loss_type', type=str, default='seqnll',
208
+ help='')
209
+ parser.add_argument('--struc_use_logsoftmax', action='store_true', help='')
210
+ parser.add_argument('--entropy_reward_weight', type=float, default=0,
211
+ help='Entropy reward, seems very interesting')
212
+ parser.add_argument('--self_cider_reward_weight', type=float, default=0,
213
+ help='self cider reward')
214
+
215
+ # Used for self critical or structure. Used when sampling is need during training
216
+ parser.add_argument('--train_sample_n', type=int, default=16,
217
+ help='number of sampled captions per image during training')
218
+ parser.add_argument('--train_sample_method', type=str, default='sample',
219
+ help='')
220
+ parser.add_argument('--train_beam_size', type=int, default=1,
221
+ help='')
222
+
223
+ # Used for self critical
224
+ parser.add_argument('--sc_sample_method', type=str, default='greedy',
225
+ help='')
226
+ parser.add_argument('--sc_beam_size', type=int, default=1,
227
+ help='')
228
+
229
+
230
+ # For diversity evaluation during training
231
+ add_diversity_opts(parser)
232
+
233
+
234
+ # config
235
+ parser.add_argument('--cfg', type=str, default=None,
236
+ help='configuration; similar to what is used in detectron')
237
+ parser.add_argument(
238
+ '--set_cfgs', dest='set_cfgs',
239
+ help='Set config keys. Key value sequence separated by whitespace.'
240
+ 'e.g. [key] [value] [key] [value]\n This has higher priority'
241
+ 'than cfg file but lower than other args. (You can only overwrite'
242
+ 'arguments that have already been defined in the config file.)',
243
+ default=[], nargs='+')
244
+ # How will config be used
245
+ # 1) read cfg argument, and load the cfg file if it's not None
246
+ # 2) Overwrite cfg argument with set_cfgs
247
+ # 3) parse config argument to args.
248
+ # 4) in the end, parse command line argument and overwrite args
249
+
250
+ # step 1: read cfg_fn
251
+ # args = parser.parse_args()
252
+ # Parse the arguments.
253
+ if parse:
254
+ args = parser.parse_args()
255
+ # For interactive environments (e.g. Jupyter)
256
+ else:
257
+ args = parser.parse_known_args()[0]
258
+ # print(args)
259
+
260
+ # Namespace => Dictionary
261
+ kwargs = vars(args)
262
+ # for k, v in optional_kwargs.items():
263
+ # setattr(args, k, v)
264
+ kwargs.update(optional_kwargs)
265
+
266
+ args = Config(**kwargs)
267
+
268
+
269
+ if args.cfg is not None or args.set_cfgs is not None:
270
+ from .config import CfgNode
271
+ if args.cfg is not None:
272
+ # print('Read Cfg')
273
+ cn = CfgNode(CfgNode.load_yaml_with_base(args.cfg))
274
+ # print(cn)
275
+ else:
276
+ cn = CfgNode()
277
+ if args.set_cfgs is not None:
278
+ cn.merge_from_list(args.set_cfgs)
279
+ for k,v in cn.items():
280
+ if not hasattr(args, k):
281
+ import os
282
+ if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0':
283
+ pass
284
+ else:
285
+ print('Warning: key %s not in args' % k)
286
+
287
+ setattr(args, k, v)
288
+
289
+ if parse:
290
+ args = parser.parse_args(namespace=args)
291
+ else:
292
+ args = parser.parse_known_args(namespace=args)[0]
293
+
294
+ # Check if args are valid
295
+ assert args.rnn_size > 0, "rnn_size should be greater than 0"
296
+ assert args.num_layers > 0, "num_layers should be greater than 0"
297
+ assert args.input_encoding_size > 0, "input_encoding_size should be greater than 0"
298
+ assert args.batch_size > 0, "batch_size should be greater than 0"
299
+ assert args.drop_prob_lm >= 0 and args.drop_prob_lm < 1, "drop_prob_lm should be between 0 and 1"
300
+ assert args.seq_per_img > 0, "seq_per_img should be greater than 0"
301
+ assert args.beam_size > 0, "beam_size should be greater than 0"
302
+ assert args.save_checkpoint_every > 0, "save_checkpoint_every should be greater than 0"
303
+ assert args.losses_log_every > 0, "losses_log_every should be greater than 0"
304
+ assert args.language_eval == 0 or args.language_eval == 1, "language_eval should be 0 or 1"
305
+ assert args.load_best_score == 0 or args.load_best_score == 1, "load_best_score should be 0 or 1"
306
+ assert args.train_only == 0 or args.train_only == 1, "train_only should be 0 or 1"
307
+
308
+ # default value for start_from and checkpoint_path
309
+ args.checkpoint_path = args.checkpoint_path or './log_%s' %args.id
310
+ args.start_from = args.start_from or args.checkpoint_path
311
+
312
+ # Deal with feature things before anything
313
+ args.use_fc, args.use_att = if_use_feat(args.caption_model)
314
+ if args.use_box: args.att_feat_size = args.att_feat_size + 5
315
+
316
+ return args
317
+
318
+
319
+ def add_eval_options(parser):
320
+ # Basic options
321
+ parser.add_argument('--batch_size', type=int, default=0,
322
+ help='if > 0 then overrule, otherwise load from checkpoint.')
323
+ parser.add_argument('--num_images', type=int, default=-1,
324
+ help='how many images to use when periodically evaluating the loss? (-1 = all)')
325
+ parser.add_argument('--language_eval', type=int, default=0,
326
+ help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.')
327
+ parser.add_argument('--dump_images', type=int, default=1,
328
+ help='Dump images into vis/imgs folder for vis? (1=yes,0=no)')
329
+ parser.add_argument('--dump_json', type=int, default=1,
330
+ help='Dump json with predictions into vis folder? (1=yes,0=no)')
331
+ parser.add_argument('--dump_path', type=int, default=0,
332
+ help='Write image paths along with predictions into vis json? (1=yes,0=no)')
333
+
334
+ # Sampling options
335
+ add_eval_sample_opts(parser)
336
+
337
+ # For evaluation on a folder of images:
338
+ parser.add_argument('--image_folder', type=str, default='',
339
+ help='If this is nonempty then will predict on the images in this folder path')
340
+ parser.add_argument('--image_root', type=str, default='',
341
+ help='In case the image paths have to be preprended with a root path to an image folder')
342
+ # For evaluation on MSCOCO images from some split:
343
+ parser.add_argument('--input_fc_dir', type=str, default='',
344
+ help='path to the directory containing the preprocessed fc feats')
345
+ parser.add_argument('--input_att_dir', type=str, default='',
346
+ help='path to the directory containing the preprocessed att feats')
347
+ parser.add_argument('--input_box_dir', type=str, default='',
348
+ help='path to the directory containing the boxes of att feats')
349
+ parser.add_argument('--input_label_h5', type=str, default='',
350
+ help='path to the h5file containing the preprocessed dataset')
351
+ parser.add_argument('--input_json', type=str, default='',
352
+ help='path to the json file containing additional info and vocab. empty = fetch from model checkpoint.')
353
+ parser.add_argument('--split', type=str, default='test',
354
+ help='if running on MSCOCO images, which split to use: val|test|train')
355
+ parser.add_argument('--coco_json', type=str, default='',
356
+ help='if nonempty then use this file in DataLoaderRaw (see docs there). Used only in MSCOCO test evaluation, where we have a specific json file of only test set images.')
357
+ # misc
358
+ parser.add_argument('--id', type=str, default='',
359
+ help='an id identifying this run/job. used only if language_eval = 1 for appending to intermediate files')
360
+ parser.add_argument('--verbose_beam', type=int, default=1,
361
+ help='if we need to print out all beam search beams.')
362
+ parser.add_argument('--verbose_loss', type=int, default=0,
363
+ help='If calculate loss using ground truth during evaluation')
364
+
365
+ def add_diversity_opts(parser):
366
+ parser.add_argument('--sample_n', type=int, default=1,
367
+ help='Diverse sampling')
368
+ parser.add_argument('--sample_n_method', type=str, default='sample',
369
+ help='sample, bs, dbs, gumbel, topk, dgreedy, dsample, dtopk, dtopp')
370
+ parser.add_argument('--eval_oracle', type=int, default=1,
371
+ help='whether to report the oracle (best-of-n) score over the sampled captions')
372
+
373
+
374
+ # Sampling related options
375
+ def add_eval_sample_opts(parser):
376
+ parser.add_argument('--sample_method', type=str, default='greedy',
377
+ help='greedy; sample; gumbel; top<int>, top<0-1>')
378
+ parser.add_argument('--beam_size', type=int, default=1,
379
+ help='used when sample_method = greedy, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.')
380
+ parser.add_argument('--max_length', type=int, default=20,
381
+ help='Maximum length during sampling')
382
+ parser.add_argument('--length_penalty', type=str, default='',
383
+ help='wu_X or avg_X, X is the alpha')
384
+ parser.add_argument('--group_size', type=int, default=1,
385
+ help='used for diverse beam search. if group_size is 1, then it\'s normal beam search')
386
+ parser.add_argument('--diversity_lambda', type=float, default=0.5,
387
+ help='used for diverse beam search. Usually from 0.2 to 0.8. Higher value of lambda produces a more diverse list')
388
+ parser.add_argument('--temperature', type=float, default=1.0,
389
+ help='temperature when sampling from distributions (i.e. when sample_method = sample). Lower = "safer" predictions.')
390
+ parser.add_argument('--decoding_constraint', type=int, default=0,
391
+ help='If 1, not allowing same word in a row')
392
+ parser.add_argument('--block_trigrams', type=int, default=0,
393
+ help='block repeated trigram.')
394
+ parser.add_argument('--remove_bad_endings', type=int, default=0,
395
+ help='Remove bad endings')
396
+ parser.add_argument('--suppress_UNK', type=int, default=1,
397
+ help='Not predicting UNK')
398
+
399
+
400
+ if __name__ == '__main__':
401
+ import sys
402
+ sys.argv = [sys.argv[0]]
403
+ args = parse_opt()
404
+ print(args)
405
+ print()
406
+ sys.argv = [sys.argv[0], '--cfg', 'configs/updown_long.yml']
407
+ args1 = parse_opt()
408
+ print(dict(set(vars(args1).items()) - set(vars(args).items())))
409
+ print()
410
+ sys.argv = [sys.argv[0], '--cfg', 'configs/updown_long.yml', '--caption_model', 'att2in2']
411
+ args2 = parse_opt()
412
+ print(dict(set(vars(args2).items()) - set(vars(args1).items())))
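The comment block inside parse_opt spells out the merge order: the YAML file named by --cfg is read first, --set_cfgs entries are merged over it, and explicitly passed command-line flags win in the final parse. A minimal sketch of driving this from code, assuming the package is importable as captioning.utils.opts and using a placeholder config path:

import sys
from captioning.utils.opts import parse_opt

# YAML values fill in anything not given on the command line;
# flags parsed in the final pass override the YAML file.
sys.argv = ['train.py', '--cfg', 'path/to/config.yml',  # placeholder path
            '--caption_model', 'transformer']
opt = parse_opt()           # use parse_opt(parse=False) inside notebooks
print(opt.caption_model)    # 'transformer', regardless of the YAML value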
captioning/utils/resnet.py ADDED
@@ -0,0 +1,71 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.models.resnet
4
+ from torchvision.models.resnet import BasicBlock, Bottleneck, model_urls  # note: model_urls is removed in newer torchvision releases
+ import torch.utils.model_zoo as model_zoo  # needed by the constructors below for load_url
5
+
6
+ class ResNet(torchvision.models.resnet.ResNet):
7
+ def __init__(self, block, layers, num_classes=1000):
8
+ super(ResNet, self).__init__(block, layers, num_classes)
9
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change
10
+ for i in range(2, 5):
11
+ getattr(self, 'layer%d'%i)[0].conv1.stride = (2,2)
12
+ getattr(self, 'layer%d'%i)[0].conv2.stride = (1,1)
13
+
14
+ def resnet18(pretrained=False):
15
+ """Constructs a ResNet-18 model.
16
+
17
+ Args:
18
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
19
+ """
20
+ model = ResNet(BasicBlock, [2, 2, 2, 2])
21
+ if pretrained:
22
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
23
+ return model
24
+
25
+
26
+ def resnet34(pretrained=False):
27
+ """Constructs a ResNet-34 model.
28
+
29
+ Args:
30
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
31
+ """
32
+ model = ResNet(BasicBlock, [3, 4, 6, 3])
33
+ if pretrained:
34
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
35
+ return model
36
+
37
+
38
+ def resnet50(pretrained=False):
39
+ """Constructs a ResNet-50 model.
40
+
41
+ Args:
42
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
43
+ """
44
+ model = ResNet(Bottleneck, [3, 4, 6, 3])
45
+ if pretrained:
46
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
47
+ return model
48
+
49
+
50
+ def resnet101(pretrained=False):
51
+ """Constructs a ResNet-101 model.
52
+
53
+ Args:
54
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
55
+ """
56
+ model = ResNet(Bottleneck, [3, 4, 23, 3])
57
+ if pretrained:
58
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
59
+ return model
60
+
61
+
62
+ def resnet152(pretrained=False):
63
+ """Constructs a ResNet-152 model.
64
+
65
+ Args:
66
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
67
+ """
68
+ model = ResNet(Bottleneck, [3, 8, 36, 3])
69
+ if pretrained:
70
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
71
+ return model
captioning/utils/resnet_utils.py ADDED
@@ -0,0 +1,27 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class myResnet(nn.Module):
6
+ def __init__(self, resnet):
7
+ super(myResnet, self).__init__()
8
+ self.resnet = resnet
9
+
10
+ def forward(self, img, att_size=14):
11
+ x = img.unsqueeze(0)
12
+
13
+ x = self.resnet.conv1(x)
14
+ x = self.resnet.bn1(x)
15
+ x = self.resnet.relu(x)
16
+ x = self.resnet.maxpool(x)
17
+
18
+ x = self.resnet.layer1(x)
19
+ x = self.resnet.layer2(x)
20
+ x = self.resnet.layer3(x)
21
+ x = self.resnet.layer4(x)
22
+
23
+ fc = x.mean(3).mean(2).squeeze()
24
+ att = F.adaptive_avg_pool2d(x,[att_size,att_size]).squeeze().permute(1, 2, 0)
25
+
26
+ return fc, att
27
+
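myResnet wraps a torchvision-style ResNet (such as the resnet101 defined above) and, for a single preprocessed image, returns the pooled fc feature plus an att_size x att_size grid of spatial features. A rough usage sketch, assuming ImageNet-style normalization has already been applied to the image tensor:

import torch
from captioning.utils.resnet import resnet101
from captioning.utils.resnet_utils import myResnet

net = myResnet(resnet101(pretrained=False)).eval()  # pretrained=True would download ImageNet weights
img = torch.randn(3, 224, 224)                      # stand-in for a normalized RGB image
with torch.no_grad():
    fc, att = net(img, att_size=14)
print(fc.shape, att.shape)  # torch.Size([2048]) and torch.Size([14, 14, 2048]) for ResNet-101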
captioning/utils/rewards.py ADDED
@@ -0,0 +1,392 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import numpy as np
6
+ import time
7
+ from collections import OrderedDict
8
+ import torch
9
+
10
+ import sys
11
+ try:
12
+ sys.path.append("cider")
13
+ from pyciderevalcap.ciderD.ciderD import CiderD
14
+ from pyciderevalcap.cider.cider import Cider
15
+ sys.path.append("coco-caption")
16
+ from pycocoevalcap.bleu.bleu import Bleu
17
+ except:
18
+ print('cider or coco-caption missing')
19
+
20
+ CiderD_scorer = None
21
+ Cider_scorer = None
22
+ Bleu_scorer = None
23
+ #CiderD_scorer = CiderD(df='corpus')
24
+
25
+
26
+ from .misc import decode_sequence
27
+
28
+ def init_scorer(cached_tokens):
29
+ global CiderD_scorer
30
+ CiderD_scorer = CiderD_scorer or CiderD(df=cached_tokens)
31
+ global Cider_scorer
32
+ Cider_scorer = Cider_scorer or Cider(df=cached_tokens)
33
+ global Bleu_scorer
34
+ Bleu_scorer = Bleu_scorer or Bleu(4)
35
+
36
+ def array_to_str(arr):
37
+ out = ''
38
+ for i in range(len(arr)):
39
+ out += str(arr[i]) + ' '
40
+ if arr[i] == 0:
41
+ break
42
+ return out.strip()
43
+
44
+ def get_self_critical_reward(greedy_res, data_gts, gen_result, opt):
45
+ batch_size = len(data_gts)
46
+ gen_result_size = gen_result.shape[0]
47
+ seq_per_img = gen_result_size // len(data_gts) # gen_result_size = batch_size * seq_per_img
48
+ assert greedy_res.shape[0] == batch_size
49
+
50
+ res = OrderedDict()
51
+ gen_result = gen_result.data.cpu().numpy()
52
+ greedy_res = greedy_res.data.cpu().numpy()
53
+ for i in range(gen_result_size):
54
+ res[i] = [array_to_str(gen_result[i])]
55
+ for i in range(batch_size):
56
+ res[gen_result_size + i] = [array_to_str(greedy_res[i])]
57
+
58
+ gts = OrderedDict()
59
+ for i in range(len(data_gts)):
60
+ gts[i] = [array_to_str(data_gts[i][j]) for j in range(len(data_gts[i]))]
61
+
62
+ res_ = [{'image_id':i, 'caption': res[i]} for i in range(len(res))]
63
+ res__ = {i: res[i] for i in range(len(res_))}
64
+ gts_ = {i: gts[i // seq_per_img] for i in range(gen_result_size)}
65
+ gts_.update({i+gen_result_size: gts[i] for i in range(batch_size)})
66
+ if opt.cider_reward_weight > 0:
67
+ _, cider_scores = CiderD_scorer.compute_score(gts_, res_)
68
+ if hasattr(opt, 'verbose') and not opt.verbose:
69
+ pass
70
+ else:
71
+ print('Cider scores:', _)
72
+ else:
73
+ cider_scores = 0
74
+ if opt.bleu_reward_weight > 0:
75
+ _, bleu_scores = Bleu_scorer.compute_score(gts_, res__)
76
+ bleu_scores = np.array(bleu_scores[3])
77
+ if hasattr(opt, 'verbose') and not opt.verbose:
78
+ pass
79
+ else:
80
+ print('Bleu scores:', _[3])
81
+ else:
82
+ bleu_scores = 0
83
+ scores = opt.cider_reward_weight * cider_scores + opt.bleu_reward_weight * bleu_scores
84
+
85
+ unnormalized_reward_mean = scores[:gen_result_size].flatten().mean()
86
+
87
+ scores = scores[:gen_result_size].reshape(batch_size, seq_per_img) - scores[-batch_size:][:, np.newaxis]
88
+
89
+ scores = scores.reshape(gen_result_size)
90
+
91
+ rewards = np.repeat(scores[:, np.newaxis], gen_result.shape[1], 1)
92
+
93
+ return rewards, unnormalized_reward_mean
94
+
95
+
96
+ def get_self_critical_clipscore_reward(greedy_res, data_gts, gen_result, opt, clipscore_model, clip_vis_feats, vocab):
97
+ batch_size = len(data_gts)
98
+ gen_result_size = gen_result.shape[0]
99
+ seq_per_img = gen_result_size // len(data_gts) # gen_result_size = batch_size * seq_per_img
100
+ assert greedy_res.shape[0] == batch_size
101
+
102
+ B = batch_size
103
+ K = seq_per_img
104
+ L = gen_result.shape[1]
105
+ assert gen_result.shape == (B*K , L)
106
+
107
+ # res = OrderedDict()
108
+ # gen_result = gen_result.data.cpu().numpy()
109
+ # greedy_res = greedy_res.data.cpu().numpy()
110
+ # for i in range(gen_result_size):
111
+ # res[i] = [array_to_str(gen_result[i])]
112
+ # for i in range(batch_size):
113
+ # res[gen_result_size + i] = [array_to_str(greedy_res[i])]
114
+
115
+ # gts = OrderedDict()
116
+ # for i in range(len(data_gts)):
117
+ # gts[i] = [array_to_str(data_gts[i][j]) for j in range(len(data_gts[i]))]
118
+
119
+ # res_ = [{'image_id':i, 'caption': res[i]} for i in range(len(res))]
120
+ # res__ = {i: res[i] for i in range(len(res_))}
121
+ # gts_ = {i: gts[i // seq_per_img] for i in range(gen_result_size)}
122
+ # gts_.update({i+gen_result_size: gts[i] for i in range(batch_size)})
123
+
124
+ # res = []
125
+ # gen_result = gen_result.data.cpu().numpy()
126
+ # greedy_res = greedy_res.data.cpu().numpy()
127
+ # # for i in range(gen_result_size):
128
+ # # res.append(array_to_str(gen_result[i]))
129
+ # res.extend(decode_sequence(vocab, gen_result))
130
+
131
+
132
+ # # for i in range(batch_size):
133
+ # # res.append(array_to_str(greedy_res[i]))
134
+ # res.extend(decode_sequence(vocab, greedy_res))
135
+
136
+ if clipscore_model.mode == 'refclip_s':
137
+ gts = []
138
+ gts_valid_mask = []
139
+ max_n_refs = max([len(_gts) for _gts in data_gts])
140
+ for i in range(len(data_gts)):
141
+ _gts = decode_sequence(vocab, data_gts[i])
142
+ # pad references
143
+ n_ref = len(_gts)
144
+ _gts.extend([''] * (max_n_refs - n_ref))
145
+ gts.extend(_gts)
146
+ gts_valid_mask.extend([1] * n_ref + [0] * (max_n_refs - n_ref))
147
+ assert len(gts) == B * max_n_refs
148
+ assert len(gts_valid_mask) == B * max_n_refs
149
+
150
+ # print(gts)
151
+ # print(gts_valid_mask)
152
+ # exit()
153
+
154
+
155
+ # assert len(res) == B * K + B, len(res)
156
+
157
+ # print(res)
158
+ # exit()
159
+
160
+ if opt.clipscore_reward_weight > 0:
161
+ with torch.no_grad():
162
+ clipscore_model.eval()
163
+
164
+ # 1) calculate reward
165
+ gen_result = gen_result.data.cpu().numpy()
166
+ res = decode_sequence(vocab, gen_result)
167
+ assert len(res) == B * K, len(res)
168
+
169
+ # [B * K, dim)
170
+ if getattr(opt, 'use_grammar', False) and not getattr(opt, 'joint_out', False):
171
+ text_pre_feat = clipscore_model.text_extract(res, proj_norm=False)
172
+
173
+ grammar_logit = clipscore_model.grammar_score_head(text_pre_feat.view(-1, 512))
174
+ grammar_prob = torch.softmax(grammar_logit, dim=-1)[:, 1]
175
+ grammar_prob = grammar_prob.view(B*K).detach()
176
+
177
+ text_feat = clipscore_model.clip_model.text_projection(text_pre_feat)
178
+ text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)
179
+
180
+ else:
181
+ text_feat = clipscore_model.text_extract(res)
182
+
183
+
184
+ assert text_feat.size() == (B * K, 512), text_feat.size()
185
+ assert clip_vis_feats.size() == (B, 512), clip_vis_feats.size()
186
+
187
+ # [B * K, dim]
188
+ vis_feat = clip_vis_feats.view(B, 1, -1).expand(-1, K, -1).contiguous().view(B * K, -1)
189
+
190
+ clip_s = clipscore_model(text_feat=text_feat, img_feat=vis_feat, mode='clip_s')
191
+ clip_s = clip_s.view(B * K).detach()
192
+
193
+ if clipscore_model.mode == 'refclip_s':
194
+ # [B * n_ref, dim]
195
+ ref_text_feat = clipscore_model.text_extract(gts)
196
+ ref_text_mask = torch.tensor(gts_valid_mask, dtype=ref_text_feat.dtype, device=ref_text_feat.device)
197
+
198
+ assert ref_text_feat.size() == (B * max_n_refs, 512), ref_text_feat.size()
199
+ assert ref_text_mask.size() == (B * max_n_refs,), ref_text_mask.size()
200
+
201
+ # [B * K]
202
+ refclip_s = clipscore_model.calc_refclip_s(
203
+ text_feat=text_feat, img_feat=vis_feat,
204
+ ref_text_feat=ref_text_feat.view(B, 1, max_n_refs, -1).expand(-1, K, -1, -1).contiguous().view(B * K * max_n_refs, -1),
205
+ ref_text_mask=ref_text_mask.view(B, 1, max_n_refs).expand(-1, K, -1).contiguous().view(B * K * max_n_refs),
206
+ clip_s=clip_s)
207
+ refclip_s = refclip_s.view(B * K).detach()
208
+
209
+ # 2) calculate reward for the baseline (greedy decoding)
210
+ greedy_res = greedy_res.data.cpu().numpy()
211
+ res = decode_sequence(vocab, greedy_res)
212
+ assert len(res) == B, len(res)
213
+
214
+ # [B, dim)
215
+
216
+ if getattr(opt, 'use_grammar', False) and getattr(opt, 'use_grammar_baseline', False) and not getattr(opt, 'joint_out', False):
217
+ text_pre_feat = clipscore_model.text_extract(res, proj_norm=False)
218
+
219
+ grammar_logit = clipscore_model.grammar_score_head(text_pre_feat.view(-1, 512))
220
+ grammar_prob_baseline = torch.softmax(grammar_logit, dim=-1)[:, 1]
221
+ grammar_prob_baseline = grammar_prob_baseline.view(B).detach()
222
+
223
+ text_feat = clipscore_model.clip_model.text_projection(text_pre_feat)
224
+ text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)
225
+ else:
226
+ text_feat = clipscore_model.text_extract(res)
227
+
228
+ assert text_feat.size() == (B, 512), text_feat.size()
229
+ assert clip_vis_feats.size() == (B, 512), clip_vis_feats.size()
230
+
231
+ vis_feat = clip_vis_feats.view(B, 512)
232
+
233
+ # [B]
234
+ clip_s_baseline = clipscore_model(text_feat=text_feat, img_feat=vis_feat, mode='clip_s')
235
+ clip_s_baseline = clip_s_baseline.view(B).detach()
236
+
237
+ if clipscore_model.mode == 'refclip_s':
238
+ # # [B * n_ref]
239
+ # ref_text_feat = clipscore_model.text_extract(gts)
240
+ # ref_text_mask = torch.tensor(gts_valid_mask, dtype=ref_text_feat.dtype, device=ref_text_feat.device)
241
+ # assert ref_text_feat.size() == (B * max_n_refs, 512), ref_text_feat.size()
242
+ # assert ref_text_mask.size() == (B * max_n_refs), ref_text_mask.size()
243
+
244
+ # [B]
245
+ refclip_s_baseline = clipscore_model.calc_refclip_s(
246
+ text_feat=text_feat, img_feat=vis_feat,
247
+ ref_text_feat=ref_text_feat,
248
+ ref_text_mask=ref_text_mask,
249
+ clip_s=clip_s_baseline)
250
+ refclip_s_baseline = refclip_s_baseline.view(B).detach()
251
+
252
+ if clipscore_model.mode == 'clip_s':
253
+ rewards = clip_s - clip_s_baseline.view(B, 1).expand(-1, K).contiguous().flatten()
254
+ unnormalized_mean_reward = clip_s.mean()
255
+ elif clipscore_model.mode == 'refclip_s':
256
+ rewards = refclip_s - refclip_s_baseline.view(B, 1).expand(-1, K).contiguous().flatten()
257
+ unnormalized_mean_reward = refclip_s.mean()
258
+
259
+ # # [B * K + B, dim)
260
+ # text_feat = clipscore_model.text_extract(res)
261
+ # assert text_feat.size() == (B * K + B, 512), text_feat.size()
262
+
263
+ # assert clip_vis_feats.size() == (B, 512), clip_vis_feats.size()
264
+
265
+ # # [B, dim] -> [B * K + B, dim]
266
+ # # vis_feat = clip_vis_feats.view(B, 1, -1).expand(-1, K + 1, -1).contiguous().view(B * (K + 1), -1)
267
+ # # vis_feat = clip_vis_feats.view(1, B, -1).expand(K + 1, -1, -1).contiguous().view((K + 1) * B, -1)
268
+
269
+ # # [B * K, dim]
270
+ # gen_vis_feat = clip_vis_feats.view(B, 1, -1).expand(-1, K, -1).contiguous().view(B * K, -1)
271
+ # # [B, dim]
272
+ # greedy_vis_feat = clip_vis_feats
273
+ # # [B * K + B, dim]
274
+ # vis_feat = torch.cat([gen_vis_feat, greedy_vis_feat], dim=0)
275
+
276
+ # # if clipscore_model.mode == 'clip_s':
277
+ # # [B * K + B, dim]
278
+ # clip_s = clipscore_model(text_feat=text_feat, img_feat=vis_feat)
279
+ # clip_s = clip_s.view(B * K + B).detach()
280
+
281
+
282
+ # if clipscore_model.mode == 'refclip_s':
283
+ # # [B * K, dim]
284
+ # ref_text_feat = clipscore_model.text_extract(gts)
285
+
286
+ # clipscore_scores = clipscore_model.calc_refclip_s(text_feat=text_feat, img_feat=vis_feat, ref_text_feat=ref_text_feat, clip_s=clip_s)
287
+ # clipscore_scores = clipscore_scores.view(B * K + B).detach()
288
+
289
+ if getattr(opt, 'use_grammar', False) and not getattr(opt, 'joint_out', False):
290
+
291
+ if getattr(opt, 'use_grammar_baseline', False):
292
+ grammar_rewards = grammar_prob - grammar_prob_baseline.view(B, 1).expand(-1, K).contiguous().flatten()
293
+ else:
294
+ grammar_rewards = grammar_prob
295
+ else:
296
+ grammar_rewards = None
297
+
298
+
299
+ if hasattr(opt, 'verbose') and not opt.verbose:
300
+ pass
301
+ else:
302
+ if clipscore_model.mode == 'clip_s':
303
+ print('CLIP-S:', rewards)
304
+ elif clipscore_model.mode == 'refclip_s':
305
+ print('RefCLIP-S:', rewards)
306
+ else:
307
+ rewards = torch.zeros(B, L)
308
+ unnormalized_mean_reward = None
309
+ grammar_rewards = None
310
+
311
+
312
+ rewards = opt.clipscore_reward_weight * rewards
313
+
314
+
315
+ # scores = scores[:gen_result_size].reshape(batch_size, seq_per_img) - scores[-batch_size:][:, np.newaxis]
316
+ # scores = scores.reshape(gen_result_size)
317
+ # rewards = np.repeat(scores[:, np.newaxis], gen_result.shape[1], 1)
318
+
319
+ # [B, K]
320
+ # scores = scores[:gen_result_size].reshape(B, K) - scores[-B:].unsqueeze(1)
321
+
322
+ # [B*K, L]
323
+ # rewards = scores.view(-1, 1).expand(-1, L).contiguous()
324
+ rewards = rewards.view(-1, 1).expand(-1, L).contiguous()
325
+
326
+ if getattr(opt, 'use_grammar', False) and not getattr(opt, 'joint_out', False):
327
+ grammar_rewards = grammar_rewards.view(-1, 1).expand(-1, L).contiguous()
328
+
329
+ return rewards, unnormalized_mean_reward, grammar_rewards
330
+
331
+ def get_scores(data_gts, gen_result, opt):
332
+ batch_size = gen_result.size(0)# batch_size = sample_size * seq_per_img
333
+ seq_per_img = batch_size // len(data_gts)
334
+
335
+ res = OrderedDict()
336
+
337
+ gen_result = gen_result.data.cpu().numpy()
338
+ for i in range(batch_size):
339
+ res[i] = [array_to_str(gen_result[i])]
340
+
341
+ gts = OrderedDict()
342
+ for i in range(len(data_gts)):
343
+ gts[i] = [array_to_str(data_gts[i][j]) for j in range(len(data_gts[i]))]
344
+
345
+ res_ = [{'image_id':i, 'caption': res[i]} for i in range(batch_size)]
346
+ res__ = {i: res[i] for i in range(batch_size)}
347
+ gts = {i: gts[i // seq_per_img] for i in range(batch_size)}
348
+ if opt.cider_reward_weight > 0:
349
+ _, cider_scores = CiderD_scorer.compute_score(gts, res_)
350
+ # print('Cider scores:', _)
351
+ if hasattr(opt, 'verbose') and not opt.verbose:
352
+ pass
353
+ else:
354
+ print('Cider scores:', _)
355
+ else:
356
+ cider_scores = 0
357
+ if opt.bleu_reward_weight > 0:
358
+ _, bleu_scores = Bleu_scorer.compute_score(gts, res__)
359
+ bleu_scores = np.array(bleu_scores[3])
360
+ # print('Bleu scores:', _[3])
361
+ if hasattr(opt, 'verbose') and not opt.verbose:
362
+ pass
363
+ else:
364
+ print('Bleu scores:', _[3])
365
+ else:
366
+ bleu_scores = 0
367
+
368
+ scores = opt.cider_reward_weight * cider_scores + opt.bleu_reward_weight * bleu_scores
369
+
370
+ return scores
371
+
372
+ def get_self_cider_scores(data_gts, gen_result, opt):
373
+ batch_size = gen_result.size(0)# batch_size = sample_size * seq_per_img
374
+ seq_per_img = batch_size // len(data_gts)
375
+
376
+ res = []
377
+
378
+ gen_result = gen_result.data.cpu().numpy()
379
+ for i in range(batch_size):
380
+ res.append(array_to_str(gen_result[i]))
381
+
382
+ scores = []
383
+ for i in range(len(data_gts)):
384
+ tmp = Cider_scorer.my_self_cider([res[i*seq_per_img:(i+1)*seq_per_img]])
385
+ def get_div(eigvals):
386
+ eigvals = np.clip(eigvals, 0, None)
387
+ return -np.log(np.sqrt(eigvals[-1]) / (np.sqrt(eigvals).sum())) / np.log(len(eigvals))
388
+ scores.append(get_div(np.linalg.eigvalsh(tmp[0]/10)))
389
+
390
+ scores = np.array(scores)
391
+
392
+ return scores
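get_self_critical_reward scores the sampled captions and the greedy captions in one scorer call, then uses each image's greedy score as a baseline: reward = score(sample) - score(greedy), repeated across every time step. A small numpy illustration of just that reshaping step, with made-up scores (2 images, 3 samples each, sequence length 4):

import numpy as np

batch_size, seq_per_img, seq_len = 2, 3, 4
# first batch_size*seq_per_img entries: sampled captions; last batch_size: greedy baselines
scores = np.array([0.9, 0.7, 0.8,  0.5, 0.6, 0.4,  0.75, 0.55])

adv = scores[:batch_size * seq_per_img].reshape(batch_size, seq_per_img) \
      - scores[-batch_size:][:, np.newaxis]                     # subtract the per-image greedy score
rewards = np.repeat(adv.reshape(-1)[:, np.newaxis], seq_len, 1)  # same reward at each time step
print(rewards.shape)  # (6, 4)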
captioning/utils/utils.py ADDED
@@ -0,0 +1,138 @@
1
+ import re
2
+ import numpy as np
3
+ import torch
4
+ import torch.distributed as dist
5
+ import collections
6
+ import logging
7
+
8
+ def get_area(pos):
9
+ """
10
+ Args
11
+ pos: [B, N, 4]
12
+ (x1, x2, y1, y2)
13
+
14
+ Return
15
+ area : [B, N]
16
+ """
17
+ # [B, N]
18
+ height = pos[:, :, 3] - pos[:, :, 2]
19
+ width = pos[:, :, 1] - pos[:, :, 0]
20
+ area = height * width
21
+ return area
22
+
23
+ def get_relative_distance(pos):
24
+ """
25
+ Args
26
+ pos: [B, N, 4]
27
+ (x1, x2, y1, y2)
28
+
29
+ Return
30
+ out : [B, N, N, 4]
31
+ """
32
+ # B, N = pos.size()[:-1]
33
+
34
+ # [B, N, N, 4]
35
+ relative_distance = pos.unsqueeze(1) - pos.unsqueeze(2)
36
+
37
+ return relative_distance
38
+
39
+
40
+ class LossMeter(object):
41
+ def __init__(self, maxlen=100):
42
+ """Computes and stores the running average"""
43
+ self.vals = collections.deque([], maxlen=maxlen)
44
+
45
+ def __len__(self):
46
+ return len(self.vals)
47
+
48
+ def update(self, new_val):
49
+ self.vals.append(new_val)
50
+
51
+ @property
52
+ def val(self):
53
+ return sum(self.vals) / len(self.vals)
54
+
55
+ def __repr__(self):
56
+ return str(self.val)
57
+
58
+
59
+ def count_parameters(model):
60
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
61
+
62
+
63
+ def load_state_dict(state_dict_path, loc='cpu'):
64
+ state_dict = torch.load(state_dict_path, map_location=loc)
65
+ # Change Multi GPU to single GPU
66
+ original_keys = list(state_dict.keys())
67
+ for key in original_keys:
68
+ if key.startswith("module."):
69
+ new_key = key[len("module."):]
70
+ state_dict[new_key] = state_dict.pop(key)
71
+ return state_dict
72
+
73
+
74
+ def set_global_logging_level(level=logging.ERROR, prefices=[""]):
75
+ """
76
+ Override logging levels of different modules based on their name as a prefix.
77
+ It needs to be invoked after the modules have been loaded so that their loggers have been initialized.
78
+
79
+ Args:
80
+ - level: desired level. e.g. logging.INFO. Optional. Default is logging.ERROR
81
+ - prefices: list of one or more str prefices to match (e.g. ["transformers", "torch"]). Optional.
82
+ Default is `[""]` to match all active loggers.
83
+ The match is a case-sensitive `module_name.startswith(prefix)`
84
+ """
85
+ prefix_re = re.compile(fr'^(?:{ "|".join(prefices) })')
86
+ for name in logging.root.manager.loggerDict:
87
+ if re.match(prefix_re, name):
88
+ logging.getLogger(name).setLevel(level)
89
+
90
+
91
+ def get_iou(anchors, gt_boxes):
92
+ """
93
+ anchors: (N, 4) torch floattensor
94
+ gt_boxes: (K, 4) torch floattensor
95
+ overlaps: (N, K) ndarray of overlap between boxes and query_boxes
96
+ """
97
+ N = anchors.size(0)
98
+
99
+ if gt_boxes.size() == (4,):
100
+ gt_boxes = gt_boxes.view(1, 4)
101
+ K = gt_boxes.size(0)
102
+
103
+ gt_boxes_area = (
104
+ (gt_boxes[:, 2] - gt_boxes[:, 0] + 1) *
105
+ (gt_boxes[:, 3] - gt_boxes[:, 1] + 1)
106
+ ).view(1, K)
107
+
108
+ anchors_area = (
109
+ (anchors[:, 2] - anchors[:, 0] + 1) *
110
+ (anchors[:, 3] - anchors[:, 1] + 1)
111
+ ).view(N, 1)
112
+
113
+ boxes = anchors.view(N, 1, 4).expand(N, K, 4)
114
+ query_boxes = gt_boxes.view(1, K, 4).expand(N, K, 4)
115
+
116
+ iw = (
117
+ torch.min(boxes[:, :, 2], query_boxes[:, :, 2])
118
+ - torch.max(boxes[:, :, 0], query_boxes[:, :, 0])
119
+ + 1
120
+ )
121
+ iw[iw < 0] = 0
122
+
123
+ ih = (
124
+ torch.min(boxes[:, :, 3], query_boxes[:, :, 3])
125
+ - torch.max(boxes[:, :, 1], query_boxes[:, :, 1])
126
+ + 1
127
+ )
128
+ ih[ih < 0] = 0
129
+
130
+ ua = anchors_area + gt_boxes_area - (iw * ih)
131
+ overlaps = iw * ih / ua
132
+
133
+ return overlaps
134
+
135
+
136
+ def xywh_to_xyxy(boxes):
137
+ """Convert [x y w h] box format to [x1 y1 x2 y2] format."""
138
+ return np.hstack((boxes[:, 0:2], boxes[:, 0:2] + boxes[:, 2:4] - 1))
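get_iou expects corner-format boxes ([x1, y1, x2, y2] in inclusive pixel coordinates, hence the +1 terms), so [x, y, w, h] boxes go through xywh_to_xyxy first. A small sanity-check sketch with two hand-made boxes:

import numpy as np
import torch
from captioning.utils.utils import get_iou, xywh_to_xyxy

boxes_xywh = np.array([[0., 0., 10., 10.],   # 10x10 box at the origin
                       [5., 5., 10., 10.]])  # overlapping 10x10 box
boxes_xyxy = torch.from_numpy(xywh_to_xyxy(boxes_xywh)).float()

iou = get_iou(boxes_xyxy, boxes_xyxy)  # (2, 2) pairwise IoU matrix
print(iou)  # 1.0 on the diagonal, 25/175 ~= 0.14 off the diagonal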
clip/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .clip import *
clip/clip.py ADDED
@@ -0,0 +1,193 @@
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+ from typing import Union, List
6
+
7
+ import torch
8
+ from PIL import Image
9
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
10
+ from tqdm import tqdm
11
+
12
+ from .model import build_model
13
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
14
+
15
+ __all__ = ["available_models", "load", "tokenize"]
16
+ _tokenizer = _Tokenizer()
17
+
18
+ _MODELS = {
19
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
20
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
21
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
22
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
23
+ }
24
+
25
+
26
+ def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
27
+ os.makedirs(root, exist_ok=True)
28
+ filename = os.path.basename(url)
29
+
30
+ expected_sha256 = url.split("/")[-2]
31
+ download_target = os.path.join(root, filename)
32
+
33
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
34
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
35
+
36
+ if os.path.isfile(download_target):
37
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
38
+ return download_target
39
+ else:
40
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
41
+
42
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
43
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
44
+ while True:
45
+ buffer = source.read(8192)
46
+ if not buffer:
47
+ break
48
+
49
+ output.write(buffer)
50
+ loop.update(len(buffer))
51
+
52
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
53
+ raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
54
+
55
+ return download_target
56
+
57
+
58
+ def _transform(n_px):
59
+ return Compose([
60
+ Resize(n_px, interpolation=Image.BICUBIC),
61
+ CenterCrop(n_px),
62
+ lambda image: image.convert("RGB"),
63
+ ToTensor(),
64
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
65
+ ])
66
+
67
+
68
+ def available_models() -> List[str]:
69
+ """Returns the names of available CLIP models"""
70
+ return list(_MODELS.keys())
71
+
72
+
73
+ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
74
+ """Load a CLIP model
75
+
76
+ Parameters
77
+ ----------
78
+ name : str
79
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
80
+
81
+ device : Union[str, torch.device]
82
+ The device to put the loaded model
83
+
84
+ jit : bool
85
+ Whether to load the optimized JIT model (default) or more hackable non-JIT model.
86
+
87
+ Returns
88
+ -------
89
+ model : torch.nn.Module
90
+ The CLIP model
91
+
92
+ preprocess : Callable[[PIL.Image], torch.Tensor]
93
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
94
+ """
95
+ if name in _MODELS:
96
+ model_path = _download(_MODELS[name])
97
+ elif os.path.isfile(name):
98
+ model_path = name
99
+ else:
100
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
101
+
102
+ try:
103
+ # loading JIT archive
104
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
105
+ state_dict = None
106
+ except RuntimeError:
107
+ # loading saved state dict
108
+ if jit:
109
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
110
+ jit = False
111
+ state_dict = torch.load(model_path, map_location="cpu")
112
+
113
+ if not jit:
114
+ model = build_model(state_dict or model.state_dict()).to(device)
115
+ if str(device) == "cpu":
116
+ model.float()
117
+ return model, _transform(model.visual.input_resolution)
118
+
119
+ # patch the device names
120
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
121
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
122
+
123
+ def patch_device(module):
124
+ graphs = [module.graph] if hasattr(module, "graph") else []
125
+ if hasattr(module, "forward1"):
126
+ graphs.append(module.forward1.graph)
127
+
128
+ for graph in graphs:
129
+ for node in graph.findAllNodes("prim::Constant"):
130
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
131
+ node.copyAttributes(device_node)
132
+
133
+ model.apply(patch_device)
134
+ patch_device(model.encode_image)
135
+ patch_device(model.encode_text)
136
+
137
+ # patch dtype to float32 on CPU
138
+ if str(device) == "cpu":
139
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
140
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
141
+ float_node = float_input.node()
142
+
143
+ def patch_float(module):
144
+ graphs = [module.graph] if hasattr(module, "graph") else []
145
+ if hasattr(module, "forward1"):
146
+ graphs.append(module.forward1.graph)
147
+
148
+ for graph in graphs:
149
+ for node in graph.findAllNodes("aten::to"):
150
+ inputs = list(node.inputs())
151
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
152
+ if inputs[i].node()["value"] == 5:
153
+ inputs[i].node().copyAttributes(float_node)
154
+
155
+ model.apply(patch_float)
156
+ patch_float(model.encode_image)
157
+ patch_float(model.encode_text)
158
+
159
+ model.float()
160
+
161
+ return model, _transform(model.input_resolution.item())
162
+
163
+
164
+ def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
165
+ """
166
+ Returns the tokenized representation of given input string(s)
167
+
168
+ Parameters
169
+ ----------
170
+ texts : Union[str, List[str]]
171
+ An input string or a list of input strings to tokenize
172
+
173
+ context_length : int
174
+ The context length to use; all CLIP models use 77 as the context length
175
+
176
+ Returns
177
+ -------
178
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
179
+ """
180
+ if isinstance(texts, str):
181
+ texts = [texts]
182
+
183
+ sot_token = _tokenizer.encoder["<|startoftext|>"]
184
+ eot_token = _tokenizer.encoder["<|endoftext|>"]
185
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
186
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
187
+
188
+ for i, tokens in enumerate(all_tokens):
189
+ if len(tokens) > context_length:
190
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
191
+ result[i, :len(tokens)] = torch.tensor(tokens)
192
+
193
+ return result
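Typical usage mirrors the upstream OpenAI CLIP release: load a model plus its preprocessing transform, tokenize some text, and encode both sides. A hedged sketch (the image path is a placeholder, and note that this repository's model.py modifies the visual branch to return grid/token features rather than a single pooled embedding, so output shapes differ from stock CLIP):

import torch
from PIL import Image
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)  # jit=False -> plain nn.Module

image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)  # placeholder image file
text = clip.tokenize(["a dog on the grass", "a plate of food"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)  # per-patch features in this modified model
    text_features = model.encode_text(text)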
clip/model.py ADDED
@@ -0,0 +1,437 @@
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+
8
+
9
+ class Bottleneck(nn.Module):
10
+ expansion = 4
11
+
12
+ def __init__(self, inplanes, planes, stride=1):
13
+ super().__init__()
14
+
15
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
16
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
17
+ self.bn1 = nn.BatchNorm2d(planes)
18
+
19
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
20
+ self.bn2 = nn.BatchNorm2d(planes)
21
+
22
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
23
+
24
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
25
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
26
+
27
+ self.relu = nn.ReLU(inplace=True)
28
+ self.downsample = None
29
+ self.stride = stride
30
+
31
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
32
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
33
+ self.downsample = nn.Sequential(OrderedDict([
34
+ ("-1", nn.AvgPool2d(stride)),
35
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
36
+ ("1", nn.BatchNorm2d(planes * self.expansion))
37
+ ]))
38
+
39
+ def forward(self, x: torch.Tensor):
40
+ identity = x
41
+
42
+ out = self.relu(self.bn1(self.conv1(x)))
43
+ out = self.relu(self.bn2(self.conv2(out)))
44
+ out = self.avgpool(out)
45
+ out = self.bn3(self.conv3(out))
46
+
47
+ if self.downsample is not None:
48
+ identity = self.downsample(x)
49
+
50
+ out += identity
51
+ out = self.relu(out)
52
+ return out
53
+
54
+
55
+ class AttentionPool2d(nn.Module):
56
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
57
+ super().__init__()
58
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
59
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
60
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
61
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
62
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
63
+ self.num_heads = num_heads
64
+
65
+ def forward(self, x):
66
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
67
+ # print(x.shape, self.positional_embedding.shape)
68
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
69
+ x = x + self.positional_embedding[0, :, None, :].to(x.dtype) # (HW+1)NC
70
+ x, _ = F.multi_head_attention_forward(
71
+ query=x, key=x, value=x,
72
+ embed_dim_to_check=x.shape[-1],
73
+ num_heads=self.num_heads,
74
+ q_proj_weight=self.q_proj.weight,
75
+ k_proj_weight=self.k_proj.weight,
76
+ v_proj_weight=self.v_proj.weight,
77
+ in_proj_weight=None,
78
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
79
+ bias_k=None,
80
+ bias_v=None,
81
+ add_zero_attn=False,
82
+ dropout_p=0,
83
+ out_proj_weight=torch.ones_like(self.q_proj.weight),
84
+ out_proj_bias=torch.zeros_like(self.q_proj.bias),
85
+ # out_proj_weight=self.c_proj.weight,
86
+ # out_proj_bias=self.c_proj.bias,
87
+ use_separate_proj_weight=True,
88
+ training=self.training,
89
+ need_weights=False
90
+ )
91
+
92
+ return x[0]
93
+
94
+
95
+ class ModifiedResNet(nn.Module):
96
+ """
97
+ A ResNet class that is similar to torchvision's but contains the following changes:
98
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
99
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
100
+ - The final pooling layer is a QKV attention instead of an average pool
101
+ """
102
+
103
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
104
+ super().__init__()
105
+ self.output_dim = output_dim
106
+ self.input_resolution = input_resolution
107
+
108
+ # the 3-layer stem
109
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
110
+ self.bn1 = nn.BatchNorm2d(width // 2)
111
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
112
+ self.bn2 = nn.BatchNorm2d(width // 2)
113
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
114
+ self.bn3 = nn.BatchNorm2d(width)
115
+ self.avgpool = nn.AvgPool2d(2)
116
+ self.relu = nn.ReLU(inplace=True)
117
+
118
+ # residual layers
119
+ self._inplanes = width # this is a *mutable* variable used during construction
120
+ self.layer1 = self._make_layer(width, layers[0])
121
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
122
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
123
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
124
+
125
+ embed_dim = width * 32 # the ResNet feature dimension
126
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
127
+
128
+ def _make_layer(self, planes, blocks, stride=1):
129
+ layers = [Bottleneck(self._inplanes, planes, stride)]
130
+
131
+ self._inplanes = planes * Bottleneck.expansion
132
+ for _ in range(1, blocks):
133
+ layers.append(Bottleneck(self._inplanes, planes))
134
+
135
+ return nn.Sequential(*layers)
136
+
137
+ def forward(self, x):
138
+ def stem(x):
139
+ for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
140
+ x = self.relu(bn(conv(x)))
141
+ x = self.avgpool(x)
142
+ return x
143
+
144
+ x = x.type(self.conv1.weight.dtype)
145
+ x = stem(x)
146
+ x = self.layer1(x)
147
+ x = self.layer2(x)
148
+ x = self.layer3(x)
149
+ x = self.layer4(x)
150
+ # print(x.shape)
151
+ # x = self.attnpool(x)
152
+ attnpool = self.attnpool(x)
153
+
154
+ return (x, attnpool)
155
+
156
+
157
+ class LayerNorm(nn.LayerNorm):
158
+ """Subclass torch's LayerNorm to handle fp16."""
159
+
160
+ def forward(self, x: torch.Tensor):
161
+ orig_type = x.dtype
162
+ ret = super().forward(x.type(torch.float32))
163
+ return ret.type(orig_type)
164
+
165
+
166
+ class QuickGELU(nn.Module):
167
+ def forward(self, x: torch.Tensor):
168
+ return x * torch.sigmoid(1.702 * x)
169
+
170
+
171
+ class ResidualAttentionBlock(nn.Module):
172
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
173
+ super().__init__()
174
+
175
+ self.attn = nn.MultiheadAttention(d_model, n_head)
176
+ self.ln_1 = LayerNorm(d_model)
177
+ self.mlp = nn.Sequential(OrderedDict([
178
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
179
+ ("gelu", QuickGELU()),
180
+ ("c_proj", nn.Linear(d_model * 4, d_model))
181
+ ]))
182
+ self.ln_2 = LayerNorm(d_model)
183
+ self.attn_mask = attn_mask
184
+
185
+ def attention(self, x: torch.Tensor):
186
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
187
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
188
+
189
+ def forward(self, x: torch.Tensor):
190
+ x = x + self.attention(self.ln_1(x))
191
+ x = x + self.mlp(self.ln_2(x))
192
+ return x
193
+
194
+
195
+ class Transformer(nn.Module):
196
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
197
+ super().__init__()
198
+ self.width = width
199
+ self.layers = layers
200
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
201
+
202
+ def forward(self, x: torch.Tensor):
203
+ return self.resblocks(x)
204
+
205
+
206
+ class VisualTransformer(nn.Module):
207
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
208
+ super().__init__()
209
+ self.input_resolution = input_resolution
210
+ self.output_dim = output_dim
211
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
212
+
213
+ scale = width ** -0.5
214
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
215
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
216
+ self.ln_pre = LayerNorm(width)
217
+
218
+ self.transformer = Transformer(width, layers, heads)
219
+
220
+ self.ln_post = LayerNorm(width)
221
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
222
+
223
+ def forward(self, x: torch.Tensor):
224
+ x = self.conv1(x) # shape = [*, width, grid, grid]
225
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
226
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
227
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
228
+ x = x + self.positional_embedding.to(x.dtype)
229
+ x = self.ln_pre(x)
230
+
231
+ x = x.permute(1, 0, 2) # NLD -> LND
232
+ x = self.transformer(x)
233
+ x = x.permute(1, 0, 2) # LND -> NLD
234
+
235
+ # x = self.ln_post(x[:, 0, :])
236
+
237
+ x = self.ln_post(x)  # applied to every token (not just the class token), so per-patch features are returned
238
+ # if self.proj is not None:
239
+ # x = x @ self.proj
240
+
241
+ return x
242
+
243
+
244
+ class CLIP(nn.Module):
245
+ def __init__(self,
246
+ embed_dim: int,
247
+ # vision
248
+ image_resolution: int,
249
+ vision_layers: Union[Tuple[int, int, int, int], int],
250
+ vision_width: int,
251
+ vision_patch_size: int,
252
+ # text
253
+ context_length: int,
254
+ vocab_size: int,
255
+ transformer_width: int,
256
+ transformer_heads: int,
257
+ transformer_layers: int
258
+ ):
259
+ super().__init__()
260
+
261
+ self.context_length = context_length
262
+
263
+ if isinstance(vision_layers, (tuple, list)):
264
+ vision_heads = vision_width * 32 // 64
265
+ self.visual = ModifiedResNet(
266
+ layers=vision_layers,
267
+ output_dim=embed_dim,
268
+ heads=vision_heads,
269
+ input_resolution=image_resolution,
270
+ width=vision_width
271
+ )
272
+ else:
273
+ vision_heads = vision_width // 64
274
+ self.visual = VisualTransformer(
275
+ input_resolution=image_resolution,
276
+ patch_size=vision_patch_size,
277
+ width=vision_width,
278
+ layers=vision_layers,
279
+ heads=vision_heads,
280
+ output_dim=embed_dim
281
+ )
282
+
283
+ self.transformer = Transformer(
284
+ width=transformer_width,
285
+ layers=transformer_layers,
286
+ heads=transformer_heads,
287
+ attn_mask=self.build_attention_mask()
288
+ )
289
+
290
+ self.vocab_size = vocab_size
291
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
292
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
293
+ self.ln_final = LayerNorm(transformer_width)
294
+
295
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
296
+ self.logit_scale = nn.Parameter(torch.ones([]))
297
+
298
+ self.initialize_parameters()
299
+
300
+ def initialize_parameters(self):
301
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
302
+ nn.init.normal_(self.positional_embedding, std=0.01)
303
+
304
+ if isinstance(self.visual, ModifiedResNet):
305
+ if self.visual.attnpool is not None:
306
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
307
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
308
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
309
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
310
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
311
+
312
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
313
+ for name, param in resnet_block.named_parameters():
314
+ if name.endswith("bn3.weight"):
315
+ nn.init.zeros_(param)
316
+
317
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
318
+ attn_std = self.transformer.width ** -0.5
319
+ fc_std = (2 * self.transformer.width) ** -0.5
320
+ for block in self.transformer.resblocks:
321
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
322
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
323
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
324
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
325
+
326
+ if self.text_projection is not None:
327
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
328
+
329
+ def build_attention_mask(self):
330
+ # lazily create causal attention mask, with full attention between the vision tokens
331
+ # pytorch uses additive attention mask; fill with -inf
332
+ mask = torch.empty(self.context_length, self.context_length)
333
+ mask.fill_(float("-inf"))
334
+ mask.triu_(1) # zero out the lower diagonal
335
+ return mask
336
+
337
+ @property
338
+ def dtype(self):
339
+ return self.visual.conv1.weight.dtype
340
+
341
+ def encode_image(self, image):
342
+ return self.visual(image.type(self.dtype))
343
+
344
+ def encode_text(self, text):
345
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
346
+
347
+ x = x + self.positional_embedding.type(self.dtype)
348
+ x = x.permute(1, 0, 2) # NLD -> LND
349
+ x = self.transformer(x)
350
+ x = x.permute(1, 0, 2) # LND -> NLD
351
+ x = self.ln_final(x).type(self.dtype)
352
+
353
+ # x.shape = [batch_size, n_ctx, transformer.width]
354
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
355
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
356
+
357
+ return x
358
+
359
+ def forward(self, image, text):
360
+ image_features = self.encode_image(image)
361
+ text_features = self.encode_text(text)
362
+
363
+ # normalized features
364
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
365
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
366
+
367
+ # cosine similarity as logits
368
+ logit_scale = self.logit_scale.exp()
369
+ logits_per_image = logit_scale * image_features @ text_features.t()
370
+ logits_per_text = logit_scale * text_features @ image_features.t()
371
+
372
+ # shape = [global_batch_size, global_batch_size]
373
+ return logits_per_image, logits_per_text
374
+
375
+
376
+ def convert_weights(model: nn.Module):
377
+ """Convert applicable model parameters to fp16"""
378
+
379
+ def _convert_weights_to_fp16(l):
380
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
381
+ l.weight.data = l.weight.data.half()
382
+ if l.bias is not None:
383
+ l.bias.data = l.bias.data.half()
384
+
385
+ if isinstance(l, nn.MultiheadAttention):
386
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
387
+ tensor = getattr(l, attr)
388
+ if tensor is not None:
389
+ tensor.data = tensor.data.half()
390
+
391
+ for name in ["text_projection", "proj"]:
392
+ if hasattr(l, name):
393
+ attr = getattr(l, name)
394
+ if attr is not None:
395
+ attr.data = attr.data.half()
396
+
397
+ model.apply(_convert_weights_to_fp16)
398
+
399
+
400
+ def build_model(state_dict: dict):
401
+ vit = "visual.proj" in state_dict
402
+
403
+ if vit:
404
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
405
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
406
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
407
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
408
+ image_resolution = vision_patch_size * grid_size
409
+ else:
410
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
411
+ vision_layers = tuple(counts)
412
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
413
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
414
+ vision_patch_size = None
415
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
416
+ image_resolution = output_width * 32
417
+
418
+ embed_dim = state_dict["text_projection"].shape[1]
419
+ context_length = state_dict["positional_embedding"].shape[0]
420
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
421
+ transformer_width = state_dict["ln_final.weight"].shape[0]
422
+ transformer_heads = transformer_width // 64
423
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
424
+
425
+ model = CLIP(
426
+ embed_dim,
427
+ image_resolution, vision_layers, vision_width, vision_patch_size,
428
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
429
+ )
430
+
431
+ for key in ["input_resolution", "context_length", "vocab_size"]:
432
+ if key in state_dict:
433
+ del state_dict[key]
434
+
435
+ convert_weights(model)
436
+ model.load_state_dict(state_dict)
437
+ return model.eval()
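For reference, a minimal usage sketch of build_model, assuming an OpenAI CLIP ResNet checkpoint (e.g. RN50) has already been downloaded locally; the path and tensor shapes are illustrative only, and the load() helper in clip/clip.py handles the same TorchScript-vs-state-dict distinction.

import torch
from clip.model import build_model

ckpt_path = "RN50.pt"  # hypothetical local path to a downloaded CLIP checkpoint
try:
    # official checkpoints are TorchScript archives
    state_dict = torch.jit.load(ckpt_path, map_location="cpu").state_dict()
except RuntimeError:
    # otherwise assume a plain state-dict file
    state_dict = torch.load(ckpt_path, map_location="cpu")

model = build_model(state_dict).float()  # build_model returns an fp16 model; cast back to fp32 for CPU use
image = torch.randn(1, 3, model.visual.input_resolution, model.visual.input_resolution)
with torch.no_grad():
    # with the ModifiedResNet variant above, encode_image returns (grid_features, pooled_embedding)
    grid, pooled = model.encode_image(image)
    print(grid.shape, pooled.shape)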
clip/simple_tokenizer.py ADDED
@@ -0,0 +1,132 @@
1
+ import gzip
2
+ import html
3
+ import os
4
+ from functools import lru_cache
5
+
6
+ import ftfy
7
+ import regex as re
8
+
9
+
10
+ @lru_cache()
11
+ def default_bpe():
12
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13
+
14
+
15
+ @lru_cache()
16
+ def bytes_to_unicode():
17
+ """
18
+ Returns a list of utf-8 bytes and a corresponding list of unicode strings.
19
+ The reversible bpe codes work on unicode strings.
20
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22
+ This is a significant percentage of your normal, say, 32K bpe vocab.
23
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24
+ It also avoids mapping to whitespace/control characters that the bpe code barfs on.
25
+ """
26
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27
+ cs = bs[:]
28
+ n = 0
29
+ for b in range(2**8):
30
+ if b not in bs:
31
+ bs.append(b)
32
+ cs.append(2**8+n)
33
+ n += 1
34
+ cs = [chr(n) for n in cs]
35
+ return dict(zip(bs, cs))
36
+
37
+
38
+ def get_pairs(word):
39
+ """Return set of symbol pairs in a word.
40
+ Word is represented as tuple of symbols (symbols being variable-length strings).
41
+ """
42
+ pairs = set()
43
+ prev_char = word[0]
44
+ for char in word[1:]:
45
+ pairs.add((prev_char, char))
46
+ prev_char = char
47
+ return pairs
48
+
49
+
50
+ def basic_clean(text):
51
+ text = ftfy.fix_text(text)
52
+ text = html.unescape(html.unescape(text))
53
+ return text.strip()
54
+
55
+
56
+ def whitespace_clean(text):
57
+ text = re.sub(r'\s+', ' ', text)
58
+ text = text.strip()
59
+ return text
60
+
61
+
62
+ class SimpleTokenizer(object):
63
+ def __init__(self, bpe_path: str = default_bpe()):
64
+ self.byte_encoder = bytes_to_unicode()
65
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67
+ merges = merges[1:49152-256-2+1]
68
+ merges = [tuple(merge.split()) for merge in merges]
69
+ vocab = list(bytes_to_unicode().values())
70
+ vocab = vocab + [v+'</w>' for v in vocab]
71
+ for merge in merges:
72
+ vocab.append(''.join(merge))
73
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74
+ self.encoder = dict(zip(vocab, range(len(vocab))))
75
+ self.decoder = {v: k for k, v in self.encoder.items()}
76
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
77
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79
+
80
+ def bpe(self, token):
81
+ if token in self.cache:
82
+ return self.cache[token]
83
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84
+ pairs = get_pairs(word)
85
+
86
+ if not pairs:
87
+ return token+'</w>'
88
+
89
+ while True:
90
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91
+ if bigram not in self.bpe_ranks:
92
+ break
93
+ first, second = bigram
94
+ new_word = []
95
+ i = 0
96
+ while i < len(word):
97
+ try:
98
+ j = word.index(first, i)
99
+ new_word.extend(word[i:j])
100
+ i = j
101
+ except ValueError:  # word.index raises ValueError when `first` is not found
102
+ new_word.extend(word[i:])
103
+ break
104
+
105
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
106
+ new_word.append(first+second)
107
+ i += 2
108
+ else:
109
+ new_word.append(word[i])
110
+ i += 1
111
+ new_word = tuple(new_word)
112
+ word = new_word
113
+ if len(word) == 1:
114
+ break
115
+ else:
116
+ pairs = get_pairs(word)
117
+ word = ' '.join(word)
118
+ self.cache[token] = word
119
+ return word
120
+
121
+ def encode(self, text):
122
+ bpe_tokens = []
123
+ text = whitespace_clean(basic_clean(text)).lower()
124
+ for token in re.findall(self.pat, text):
125
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127
+ return bpe_tokens
128
+
129
+ def decode(self, tokens):
130
+ text = ''.join([self.decoder[token] for token in tokens])
131
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132
+ return text
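A short, illustrative round-trip with the tokenizer (the caption string is arbitrary); tokenize() in clip/clip.py wraps the same encode() call, adding the start/end tokens and padding to the model's context length.

from clip.simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()  # expects bpe_simple_vocab_16e6.txt.gz next to this file (see default_bpe above)
sot = tokenizer.encoder["<|startoftext|>"]
eot = tokenizer.encoder["<|endoftext|>"]

ids = [sot] + tokenizer.encode("a photo of a dog") + [eot]
print(ids)                    # BPE token ids
print(tokenizer.decode(ids))  # roughly "<|startoftext|>a photo of a dog <|endoftext|>"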
cog.yaml ADDED
@@ -0,0 +1,26 @@
1
+ build:
2
+ cuda: "11.0"
3
+ gpu: true
4
+ python_version: "3.7"
5
+ system_packages:
6
+ - "libgl1-mesa-glx"
7
+ - "libglib2.0-0"
8
+ python_packages:
9
+ - "ipython==7.21.0"
10
+ - "transformers==4.19.2"
11
+ - "h5py==3.7.0"
12
+ - "numpy==1.20.3"
13
+ - "pandas==1.3.3"
14
+ - "scikit-image==0.18.3"
15
+ - "ipywidgets==7.7.0"
16
+ - "wandb==0.12.17"
17
+ - "bert-score==0.3.11"
18
+ - "ftfy==6.1.1"
19
+ - "timm==0.5.4"
20
+ - "lmdbdict==0.2.2"
21
+ - "yacs==0.1.8"
22
+ - "pyemd==0.5.1"
23
+ - "gensim==4.2.0"
24
+ - "pytorch-lightning==1.6.3"
25
+
26
+ predict: "predict.py:Predictor"
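cog.yaml points Cog at a Predictor class in predict.py, which is not shown in this part of the diff. Purely for orientation, a hypothetical skeleton of such a class with Cog's Python API; the real predict.py will differ.

from cog import BasePredictor, Input, Path

class Predictor(BasePredictor):
    def setup(self):
        # load the captioning model and CLIP weights once, when the container starts
        ...

    def predict(self, image: Path = Input(description="Input image")) -> str:
        # extract visual features, generate a caption, and return the caption text
        ...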
configs/phase1/FineCapEval_clipRN50_mle.yml ADDED
@@ -0,0 +1,60 @@
1
+ caption_model: transformer
2
+ noamopt: true
3
+ noamopt_warmup: 20000
4
+ label_smoothing: 0.0
5
+ input_json: data/FineCapEval.json
6
+ input_label_h5: none
7
+ input_fc_dir: data/FineCapEval_clip_RN50_fc
8
+ input_att_dir: data/FineCapEval_clip_RN50_att
9
+ input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
10
+
11
+ seq_per_img: 5
12
+ batch_size: 200
13
+ learning_rate: 0.0005
14
+
15
+ checkpoint_path: ./save/clipRN50_mle/clipRN50_mle
16
+
17
+ # clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
18
+
19
+ # Notice: because I'm too lazy, I reuse the RNN option names to set the transformer hyperparameters:
20
+ # N=num_layers
21
+ # d_model=input_encoding_size
22
+ # d_ff=rnn_size
23
+
24
+ # will be ignored
25
+ num_layers: 6
26
+ input_encoding_size: 512
27
+ rnn_size: 2048
28
+
29
+ # Transformer config
30
+ N_enc: 6
31
+ N_dec: 6
32
+ d_model: 512
33
+ d_ff: 2048
34
+ num_att_heads: 8
35
+ dropout: 0.1
36
+
37
+
38
+ learning_rate_decay_start: 0
39
+ scheduled_sampling_start: -1
40
+ save_checkpoint_every: 3000
41
+ language_eval: 1
42
+ val_images_use: 5000
43
+ max_epochs: 15
44
+ train_sample_n: 5
45
+
46
+ REFORWARD: false
47
+
48
+ # _BASE_: transformer.yml
49
+ reduce_on_plateau: false
50
+ noamopt: false
51
+ learning_rate: 0.000005
52
+ learning_rate_decay_start: -1
53
+
54
+ self_critical_after: 15
55
+ max_epochs: 50
56
+
57
+ verbose: false
58
+ precision: 32
59
+
60
+ use_clipscore: false
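Note that this config (like several below) repeats keys such as noamopt, learning_rate, learning_rate_decay_start and max_epochs; with a standard YAML loader the last occurrence wins, so the effective values are the ones near the bottom. A quick check with PyYAML (how captioning/utils/config.py resolves these may differ):

import yaml  # PyYAML

with open("configs/phase1/FineCapEval_clipRN50_mle.yml") as f:
    cfg = yaml.safe_load(f)

# duplicate keys resolve to their last occurrence
print(cfg["noamopt"])        # False  (true near the top, false near the bottom)
print(cfg["learning_rate"])  # 5e-06
print(cfg["max_epochs"])     # 50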
configs/phase1/clipRN50_mle.yml ADDED
@@ -0,0 +1,52 @@
1
+ caption_model: transformer
2
+ noamopt: true
3
+ # noamopt: false
4
+ noamopt_warmup: 20000
5
+ label_smoothing: 0.0
6
+ input_json: data/cocotalk.json
7
+ input_label_h5: data/cocotalk_label.h5
8
+ input_fc_dir: data/cocotalk_clip_RN50_fc
9
+ input_att_dir: data/cocotalk_clip_RN50_att
10
+ input_clipscore_vis_dir: data/cocotalk_clipscore_vis
11
+ seq_per_img: 5
12
+ # batch_size: 600
13
+ batch_size: 200
14
+
15
+ learning_rate: 0.0005
16
+
17
+ # checkpoint_path: ./save/trans_clip_rn50_sc_pl
18
+ checkpoint_path: save/clipRN50_mle/clipRN50_mle
19
+
20
+ # Notice: because I'm too lazy, I reuse the RNN option names to set the transformer hyperparameters:
21
+ # N=num_layers
22
+ # d_model=input_encoding_size
23
+ # d_ff=rnn_size
24
+
25
+ # will be ignored
26
+ num_layers: 6
27
+ input_encoding_size: 512
28
+ rnn_size: 2048
29
+
30
+ # Transformer config
31
+ N_enc: 6
32
+ N_dec: 6
33
+ d_model: 512
34
+ d_ff: 2048
35
+ num_att_heads: 8
36
+ dropout: 0.1
37
+
38
+
39
+ learning_rate_decay_start: 0
40
+ scheduled_sampling_start: -1
41
+ save_checkpoint_every: 3000
42
+ language_eval: 1
43
+ val_images_use: 5000
44
+ # max_epochs: 15
45
+ max_epochs: 25
46
+ train_sample_n: 5
47
+
48
+ REFORWARD: false
49
+
50
+
51
+ verbose: false
52
+ precision: 16
configs/phase1/transformer.yml ADDED
@@ -0,0 +1,41 @@
1
+ caption_model: transformer
2
+ noamopt: true
3
+ noamopt_warmup: 20000
4
+ label_smoothing: 0.0
5
+ input_json: data/cocotalk.json
6
+ input_label_h5: data/cocotalk_label.h5
7
+ input_att_dir: data/cocotalk_att
8
+ seq_per_img: 5
9
+ batch_size: 10
10
+ learning_rate: 0.0005
11
+
12
+ checkpoint_path: ./save/trans_rn50_sc
13
+
14
+ # Notice: because I'm too lazy, I reuse the RNN option names to set the transformer hyperparameters:
15
+ # N=num_layers
16
+ # d_model=input_encoding_size
17
+ # d_ff=rnn_size
18
+
19
+ # will be ignored
20
+ num_layers: 6
21
+ input_encoding_size: 512
22
+ rnn_size: 2048
23
+
24
+ # Transformer config
25
+ N_enc: 6
26
+ N_dec: 6
27
+ d_model: 512
28
+ d_ff: 2048
29
+ num_att_heads: 8
30
+ dropout: 0.1
31
+
32
+
33
+ learning_rate_decay_start: 0
34
+ scheduled_sampling_start: -1
35
+ save_checkpoint_every: 3000
36
+ language_eval: 1
37
+ val_images_use: 5000
38
+ max_epochs: 15
39
+ train_sample_n: 5
40
+
41
+ REFORWARD: false
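noamopt / noamopt_warmup refer to the warmup-then-decay schedule from "Attention Is All You Need". As a reference for how these options interact with d_model, a standalone sketch of that schedule (the repo's own implementation may apply an extra scaling factor):

def noam_lr(step, d_model=512, warmup=20000, factor=1.0):
    """Noam schedule: linear warmup for `warmup` steps, then step**-0.5 decay."""
    step = max(step, 1)
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

# the learning rate rises until step 20000, then decays
for s in [1, 5000, 20000, 100000]:
    print(s, noam_lr(s))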
configs/phase2/FineCapEval_clipRN50_cider.yml ADDED
@@ -0,0 +1,61 @@
1
+ caption_model: transformer
2
+ noamopt: true
3
+ noamopt_warmup: 20000
4
+ label_smoothing: 0.0
5
+ input_json: data/FineCapEval.json
6
+ input_label_h5: none
7
+ input_fc_dir: data/FineCapEval_clip_RN50_fc
8
+ input_att_dir: data/FineCapEval_clip_RN50_att
9
+ input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
10
+
11
+ seq_per_img: 5
12
+ batch_size: 200
13
+ learning_rate: 0.0005
14
+
15
+ checkpoint_path: ./save/clipRN50_cider/clipRN50_cider
16
+
17
+ # clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
18
+
19
+ # Notice: because I'm too lazy, I reuse the RNN option names to set the transformer hyperparameters:
20
+ # N=num_layers
21
+ # d_model=input_encoding_size
22
+ # d_ff=rnn_size
23
+
24
+ # will be ignored
25
+ num_layers: 6
26
+ input_encoding_size: 512
27
+ rnn_size: 2048
28
+
29
+ # Transformer config
30
+ N_enc: 6
31
+ N_dec: 6
32
+ d_model: 512
33
+ d_ff: 2048
34
+ num_att_heads: 8
35
+ dropout: 0.1
36
+
37
+
38
+ learning_rate_decay_start: 0
39
+ scheduled_sampling_start: -1
40
+ save_checkpoint_every: 3000
41
+ language_eval: 1
42
+ val_images_use: 5000
43
+ max_epochs: 15
44
+ train_sample_n: 5
45
+
46
+ REFORWARD: false
47
+
48
+ # _BASE_: transformer.yml
49
+ reduce_on_plateau: false
50
+ noamopt: false
51
+ learning_rate: 0.000005
52
+ learning_rate_decay_start: -1
53
+
54
+ self_critical_after: 15
55
+ max_epochs: 50
56
+
57
+ verbose: false
58
+ precision: 32
59
+
60
+ # use_clipscore: true
61
+ use_clipscore: false
configs/phase2/FineCapEval_clipRN50_cider_clips.yml ADDED
@@ -0,0 +1,65 @@
1
+ caption_model: transformer
2
+ noamopt: true
3
+ noamopt_warmup: 20000
4
+ label_smoothing: 0.0
5
+ input_json: data/FineCapEval.json
6
+ input_label_h5: none
7
+ input_fc_dir: data/FineCapEval_clip_RN50_fc
8
+ input_att_dir: data/FineCapEval_clip_RN50_att
9
+ input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
10
+
11
+ seq_per_img: 5
12
+ batch_size: 200
13
+ learning_rate: 0.0005
14
+
15
+ checkpoint_path: ./save/clipRN50_cider_clips/clipRN50_cider_clips
16
+
17
+ # clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
18
+
19
+ # Notice: because I'm too lazy, I reuse the RNN option names to set the transformer hyperparameters:
20
+ # N=num_layers
21
+ # d_model=input_encoding_size
22
+ # d_ff=rnn_size
23
+
24
+ # will be ignored
25
+ num_layers: 6
26
+ input_encoding_size: 512
27
+ rnn_size: 2048
28
+
29
+ # Transformer config
30
+ N_enc: 6
31
+ N_dec: 6
32
+ d_model: 512
33
+ d_ff: 2048
34
+ num_att_heads: 8
35
+ dropout: 0.1
36
+
37
+
38
+ learning_rate_decay_start: 0
39
+ scheduled_sampling_start: -1
40
+ save_checkpoint_every: 3000
41
+ language_eval: 1
42
+ val_images_use: 5000
43
+ max_epochs: 15
44
+ train_sample_n: 5
45
+
46
+ REFORWARD: false
47
+
48
+ # _BASE_: transformer.yml
49
+ reduce_on_plateau: false
50
+ noamopt: false
51
+ learning_rate: 0.000005
52
+ learning_rate_decay_start: -1
53
+
54
+ self_critical_after: 15
55
+ max_epochs: 50
56
+
57
+ verbose: false
58
+ precision: 32
59
+
60
+ # use_clipscore: true
61
+ use_clipscore: false
62
+ clipscore_reward_weight: 2.0
63
+ clipscore_mode: clip_s
64
+
65
+ use_multi_rewards: true
configs/phase2/FineCapEval_clipRN50_clips.yml ADDED
@@ -0,0 +1,64 @@
1
+ caption_model: transformer
2
+ noamopt: true
3
+ noamopt_warmup: 20000
4
+ label_smoothing: 0.0
5
+ input_json: data/FineCapEval.json
6
+ input_label_h5: none
7
+ input_fc_dir: data/FineCapEval_clip_RN50_fc
8
+ input_att_dir: data/FineCapEval_clip_RN50_att
9
+ input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
10
+ seq_per_img: 5
11
+ batch_size: 160
12
+ learning_rate: 0.0005
13
+
14
+ checkpoint_path: ./save/clipRN50_clips/clipRN50_clips
15
+
16
+ use_multi_rewards: false
17
+ use_grammar: false
18
+ use_grammar_baseline: false
19
+ # clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
20
+
21
+ # Notice: because I'm too lazy, I reuse the RNN option names to set the transformer hyperparameters:
22
+ # N=num_layers
23
+ # d_model=input_encoding_size
24
+ # d_ff=rnn_size
25
+
26
+ # will be ignored
27
+ num_layers: 6
28
+ input_encoding_size: 512
29
+ rnn_size: 2048
30
+
31
+ # Transformer config
32
+ N_enc: 6
33
+ N_dec: 6
34
+ d_model: 512
35
+ d_ff: 2048
36
+ num_att_heads: 8
37
+ dropout: 0.1
38
+
39
+
40
+ learning_rate_decay_start: 0
41
+ scheduled_sampling_start: -1
42
+ save_checkpoint_every: 3000
43
+ language_eval: 0
44
+ val_images_use: 5000
45
+ max_epochs: 15
46
+ train_sample_n: 5
47
+
48
+ REFORWARD: false
49
+
50
+ # _BASE_: transformer.yml
51
+ reduce_on_plateau: false
52
+ noamopt: false
53
+ learning_rate: 0.000005
54
+ learning_rate_decay_start: -1
55
+
56
+ self_critical_after: 15
57
+ max_epochs: 50
58
+
59
+ verbose: false
60
+ precision: 32
61
+
62
+ # use_clipscore: true
63
+ use_clipscore: false
64
+ clipscore_reward_weight: 2.0
configs/phase2/FineCapEval_clipRN50_clips_grammar.yml ADDED
@@ -0,0 +1,64 @@
1
+ caption_model: transformer
2
+ noamopt: true
3
+ noamopt_warmup: 20000
4
+ label_smoothing: 0.0
5
+ input_json: data/FineCapEval.json
6
+ input_label_h5: none
7
+ input_fc_dir: data/FineCapEval_clip_RN50_fc
8
+ input_att_dir: data/FineCapEval_clip_RN50_att
9
+ input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
10
+ seq_per_img: 5
11
+ batch_size: 160
12
+ learning_rate: 0.0005
13
+
14
+ checkpoint_path: ./save/clipRN50_clips_grammar/clipRN50_clips_grammar
15
+
16
+ use_multi_rewards: true
17
+ use_grammar: true
18
+ use_grammar_baseline: true
19
+ # clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
20
+
21
+ # Notice: because I'm too lazy, I reuse the RNN option names to set the transformer hyperparameters:
22
+ # N=num_layers
23
+ # d_model=input_encoding_size
24
+ # d_ff=rnn_size
25
+
26
+ # will be ignored
27
+ num_layers: 6
28
+ input_encoding_size: 512
29
+ rnn_size: 2048
30
+
31
+ # Transformer config
32
+ N_enc: 6
33
+ N_dec: 6
34
+ d_model: 512
35
+ d_ff: 2048
36
+ num_att_heads: 8
37
+ dropout: 0.1
38
+
39
+
40
+ learning_rate_decay_start: 0
41
+ scheduled_sampling_start: -1
42
+ save_checkpoint_every: 3000
43
+ language_eval: 0
44
+ val_images_use: 5000
45
+ max_epochs: 15
46
+ train_sample_n: 5
47
+
48
+ REFORWARD: false
49
+
50
+ # _BASE_: transformer.yml
51
+ reduce_on_plateau: false
52
+ noamopt: false
53
+ learning_rate: 0.000005
54
+ learning_rate_decay_start: -1
55
+
56
+ self_critical_after: 15
57
+ max_epochs: 50
58
+
59
+ verbose: false
60
+ precision: 32
61
+
62
+ # use_clipscore: true
63
+ use_clipscore: false
64
+ clipscore_reward_weight: 2.0
configs/phase2/clipRN50_cider.yml ADDED
@@ -0,0 +1,58 @@
1
+ caption_model: transformer
2
+ noamopt: true
3
+ noamopt_warmup: 20000
4
+ label_smoothing: 0.0
5
+ input_json: data/cocotalk.json
6
+ input_label_h5: data/cocotalk_label.h5
7
+ input_fc_dir: data/cocotalk_clip_RN50_fc
8
+ input_att_dir: data/cocotalk_clip_RN50_att
9
+ # used only for evaluation
10
+ input_clipscore_vis_dir: data/cocotalk_clipscore_vis
11
+
12
+ seq_per_img: 5
13
+ batch_size: 200
14
+ learning_rate: 0.0005
15
+
16
+ # checkpoint_path: ./save/trans_clip_rn50_sc_pl_scst_cider
17
+ checkpoint_path: save/clipRN50_cider/clipRN50_cider
18
+
19
+ # Notice: because I'm too lazy, I reuse the RNN option names to set the transformer hyperparameters:
20
+ # N=num_layers
21
+ # d_model=input_encoding_size
22
+ # d_ff=rnn_size
23
+
24
+ # will be ignored
25
+ num_layers: 6
26
+ input_encoding_size: 512
27
+ rnn_size: 2048
28
+
29
+ # Transformer config
30
+ N_enc: 6
31
+ N_dec: 6
32
+ d_model: 512
33
+ d_ff: 2048
34
+ num_att_heads: 8
35
+ dropout: 0.1
36
+
37
+
38
+ learning_rate_decay_start: 0
39
+ scheduled_sampling_start: -1
40
+ save_checkpoint_every: 3000
41
+ language_eval: 1
42
+ val_images_use: 5000
43
+ max_epochs: 15
44
+ train_sample_n: 5
45
+
46
+ REFORWARD: false
47
+
48
+ # _BASE_: transformer.yml
49
+ reduce_on_plateau: false
50
+ noamopt: false
51
+ learning_rate: 0.000005
52
+ learning_rate_decay_start: -1
53
+
54
+ self_critical_after: 15
55
+ max_epochs: 40
56
+
57
+ verbose: false
58
+ precision: 32
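self_critical_after switches training to self-critical sequence training (SCST), where the reward of a sampled caption is baselined by the reward of the greedy caption. A minimal sketch of that loss, assuming per-sequence rewards and log-probabilities are already computed (the repo's actual implementation additionally handles token masking and batching):

import torch

def scst_loss(sample_logprobs, sample_reward, greedy_reward):
    """REINFORCE with the greedy caption's reward as baseline (one scalar per sequence)."""
    advantage = sample_reward - greedy_reward        # positive if sampling beat greedy decoding
    return -(advantage * sample_logprobs).mean()

# toy batch of 3 captions
logp = torch.tensor([-12.3, -9.8, -15.1])            # summed log-prob of each sampled caption
r_sample = torch.tensor([0.9, 1.2, 0.4])             # e.g. CIDEr of the sampled captions
r_greedy = torch.tensor([1.0, 1.0, 1.0])             # CIDEr of the greedy captions
print(scst_loss(logp, r_sample, r_greedy))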
configs/phase2/clipRN50_cider_clips.yml ADDED
@@ -0,0 +1,61 @@
1
+ caption_model: transformer
2
+ noamopt: true
3
+ noamopt_warmup: 20000
4
+ label_smoothing: 0.0
5
+ input_json: data/cocotalk.json
6
+ input_label_h5: data/cocotalk_label.h5
7
+ input_fc_dir: data/cocotalk_clip_RN50_fc
8
+ input_att_dir: data/cocotalk_clip_RN50_att
9
+ input_clipscore_vis_dir: data/cocotalk_clipscore_vis
10
+ seq_per_img: 5
11
+ batch_size: 160
12
+ learning_rate: 0.0005
13
+
14
+ checkpoint_path: save/clipRN50_cider_clips/clipRN50_cider_clips
15
+
16
+ # Notice: because I'm too lazy, I reuse the RNN option names to set the transformer hyperparameters:
17
+ # N=num_layers
18
+ # d_model=input_encoding_size
19
+ # d_ff=rnn_size
20
+
21
+ # will be ignored
22
+ num_layers: 6
23
+ input_encoding_size: 512
24
+ rnn_size: 2048
25
+
26
+ # Transformer config
27
+ N_enc: 6
28
+ N_dec: 6
29
+ d_model: 512
30
+ d_ff: 2048
31
+ num_att_heads: 8
32
+ dropout: 0.1
33
+
34
+
35
+ learning_rate_decay_start: 0
36
+ scheduled_sampling_start: -1
37
+ save_checkpoint_every: 3000
38
+ language_eval: 1
39
+ val_images_use: 5000
40
+ max_epochs: 15
41
+ train_sample_n: 5
42
+
43
+ REFORWARD: false
44
+
45
+ # _BASE_: transformer.yml
46
+ reduce_on_plateau: false
47
+ noamopt: false
48
+ learning_rate: 0.000005
49
+ learning_rate_decay_start: -1
50
+
51
+ self_critical_after: 15
52
+ max_epochs: 40
53
+
54
+ verbose: false
55
+ precision: 32
56
+
57
+ use_clipscore: true
58
+ clipscore_reward_weight: 2.0
59
+ clipscore_mode: clip_s
60
+
61
+ use_multi_rewards: true
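use_multi_rewards together with clipscore_reward_weight: 2.0 and clipscore_mode: clip_s suggests the CIDEr reward is mixed with a weighted CLIP-S reward. Purely as an illustration of such a mix (the exact combination used by the repo's reward/loss modules may differ):

import numpy as np

def mixed_reward(cider, clip_s, clipscore_reward_weight=2.0):
    """Hypothetical weighted sum of per-caption CIDEr and CLIP-S rewards."""
    return np.asarray(cider) + clipscore_reward_weight * np.asarray(clip_s)

print(mixed_reward([0.8, 1.1, 0.5], [0.70, 0.65, 0.72]))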
configs/phase2/clipRN50_clips.yml ADDED
@@ -0,0 +1,58 @@
1
+ caption_model: transformer
2
+ noamopt: true
3
+ noamopt_warmup: 20000
4
+ label_smoothing: 0.0
5
+ input_json: data/cocotalk.json
6
+ input_label_h5: data/cocotalk_label.h5
7
+ input_fc_dir: data/cocotalk_clip_RN50_fc
8
+ input_att_dir: data/cocotalk_clip_RN50_att
9
+ input_clipscore_vis_dir: data/cocotalk_clipscore_vis
10
+ seq_per_img: 5
11
+ batch_size: 160
12
+ learning_rate: 0.0005
13
+
14
+ checkpoint_path: save/clipRN50_clips/clipRN50_clips
15
+
16
+ # Notice: because I'm too lazy, I reuse the RNN option names to set the transformer hyperparameters:
17
+ # N=num_layers
18
+ # d_model=input_encoding_size
19
+ # d_ff=rnn_size
20
+
21
+ # will be ignored
22
+ num_layers: 6
23
+ input_encoding_size: 512
24
+ rnn_size: 2048
25
+
26
+ # Transformer config
27
+ N_enc: 6
28
+ N_dec: 6
29
+ d_model: 512
30
+ d_ff: 2048
31
+ num_att_heads: 8
32
+ dropout: 0.1
33
+
34
+
35
+ learning_rate_decay_start: 0
36
+ scheduled_sampling_start: -1
37
+ save_checkpoint_every: 3000
38
+ language_eval: 1
39
+ val_images_use: 5000
40
+ max_epochs: 15
41
+ train_sample_n: 5
42
+
43
+ REFORWARD: false
44
+
45
+ # _BASE_: transformer.yml
46
+ reduce_on_plateau: false
47
+ noamopt: false
48
+ learning_rate: 0.000005
49
+ learning_rate_decay_start: -1
50
+
51
+ self_critical_after: 15
52
+ max_epochs: 40
53
+
54
+ verbose: false
55
+ precision: 32
56
+
57
+ use_clipscore: true
58
+ clipscore_reward_weight: 2.0
configs/phase2/clipRN50_clips_grammar.yml ADDED
@@ -0,0 +1,64 @@
1
+ caption_model: transformer
2
+ noamopt: true
3
+ noamopt_warmup: 20000
4
+ label_smoothing: 0.0
5
+ input_json: data/cocotalk.json
6
+ input_label_h5: data/cocotalk_label.h5
7
+ input_fc_dir: data/cocotalk_clip_RN50_fc
8
+ input_att_dir: data/cocotalk_clip_RN50_att
9
+ input_clipscore_vis_dir: data/cocotalk_clipscore_vis
10
+ seq_per_img: 5
11
+ batch_size: 160
12
+ learning_rate: 0.0005
13
+
14
+ checkpoint_path: save/clipRN50_clips_grammar/clipRN50_clips_grammar
15
+
16
+ use_multi_rewards: true
17
+ use_grammar: true
18
+ use_grammar_baseline: true
19
+ # clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
20
+ clip_load_path: 'retrieval/save/clip_negative_text/clip_negative_text-epoch=12.ckpt'
21
+
22
+ # Notice: because I'm too lazy, I reuse the RNN option names to set the transformer hyperparameters:
23
+ # N=num_layers
24
+ # d_model=input_encoding_size
25
+ # d_ff=rnn_size
26
+
27
+ # will be ignored
28
+ num_layers: 6
29
+ input_encoding_size: 512
30
+ rnn_size: 2048
31
+
32
+ # Transformer config
33
+ N_enc: 6
34
+ N_dec: 6
35
+ d_model: 512
36
+ d_ff: 2048
37
+ num_att_heads: 8
38
+ dropout: 0.1
39
+
40
+
41
+ learning_rate_decay_start: 0
42
+ scheduled_sampling_start: -1
43
+ save_checkpoint_every: 3000
44
+ language_eval: 1
45
+ val_images_use: 5000
46
+ max_epochs: 15
47
+ train_sample_n: 5
48
+
49
+ REFORWARD: false
50
+
51
+ # _BASE_: transformer.yml
52
+ reduce_on_plateau: false
53
+ noamopt: false
54
+ learning_rate: 0.000005
55
+ learning_rate_decay_start: -1
56
+
57
+ self_critical_after: 15
58
+ max_epochs: 40
59
+
60
+ verbose: false
61
+ precision: 32
62
+
63
+ use_clipscore: true
64
+ clipscore_reward_weight: 2.0