fenglinliu committed on
Commit
6e32a75
1 Parent(s): c168557

Upload 55 files

This view is limited to 50 files because it contains too many changes. See raw diff

Files changed (50)
  1. .gitattributes +3 -0
  2. LICENSE +201 -0
  3. PromptNet.py +114 -0
  4. app.py +121 -0
  5. ckpts/few-shot.pth +3 -0
  6. data/annotation.json +3 -0
  7. decoder_config/decoder_config.pkl +3 -0
  8. example_figs/example_fig1.jpg.png +0 -0
  9. example_figs/example_fig2.jpg.jpg +0 -0
  10. example_figs/example_fig3.jpg.png +0 -0
  11. inference.py +110 -0
  12. models/models.py +125 -0
  13. models/r2gen.py +63 -0
  14. modules/att_model.py +319 -0
  15. modules/att_models.py +120 -0
  16. modules/caption_model.py +401 -0
  17. modules/config.pkl +3 -0
  18. modules/dataloader.py +59 -0
  19. modules/dataloaders.py +62 -0
  20. modules/dataset.py +68 -0
  21. modules/datasets.py +57 -0
  22. modules/decoder.py +50 -0
  23. modules/encoder_decoder.py +391 -0
  24. modules/loss.py +22 -0
  25. modules/metrics.py +33 -0
  26. modules/optimizers.py +18 -0
  27. modules/tester.py +144 -0
  28. modules/tokenizers.py +95 -0
  29. modules/trainer.py +255 -0
  30. modules/utils.py +55 -0
  31. modules/visual_extractor.py +53 -0
  32. prompt/prompt.pth +3 -0
  33. pycocoevalcap/README.md +23 -0
  34. pycocoevalcap/__init__.py +1 -0
  35. pycocoevalcap/bleu/LICENSE +19 -0
  36. pycocoevalcap/bleu/__init__.py +1 -0
  37. pycocoevalcap/bleu/bleu.py +57 -0
  38. pycocoevalcap/bleu/bleu_scorer.py +268 -0
  39. pycocoevalcap/cider/__init__.py +1 -0
  40. pycocoevalcap/cider/cider.py +55 -0
  41. pycocoevalcap/cider/cider_scorer.py +197 -0
  42. pycocoevalcap/eval.py +74 -0
  43. pycocoevalcap/license.txt +26 -0
  44. pycocoevalcap/meteor/__init__.py +1 -0
  45. pycocoevalcap/meteor/meteor-1.5.jar +3 -0
  46. pycocoevalcap/meteor/meteor.py +88 -0
  47. pycocoevalcap/rouge/__init__.py +1 -0
  48. pycocoevalcap/rouge/rouge.py +105 -0
  49. pycocoevalcap/tokenizer/__init__.py +1 -0
  50. pycocoevalcap/tokenizer/ptbtokenizer.py +76 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ data/annotation.json filter=lfs diff=lfs merge=lfs -text
+ pycocoevalcap/meteor/meteor-1.5.jar filter=lfs diff=lfs merge=lfs -text
+ pycocoevalcap/tokenizer/stanford-corenlp-3.4.1.jar filter=lfs diff=lfs merge=lfs -text
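
Note: the three rules added above route the large binaries in this commit through Git LFS, so a plain clone without LFS fetches only small pointer stubs (the `version`/`oid`/`size` blocks shown for `ckpts/few-shot.pth` and friends below). A minimal sketch, not part of this commit and assuming the repo is checked out in the working directory, for verifying whether a file is still a pointer:

```python
# Minimal sketch: detect whether a checked-out file is still a Git LFS
# pointer stub rather than the real binary payload.
def is_lfs_pointer(path: str) -> bool:
    with open(path, "rb") as f:
        head = f.read(100)  # a pointer stub is ~130 bytes; 100 covers its first line
    return head.startswith(b"version https://git-lfs.github.com/spec/v1")

for f in ("data/annotation.json", "ckpts/few-shot.pth"):
    print(f, is_lfs_pointer(f))
```

If this prints `True`, run `git lfs pull` before using the checkpoints.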
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
PromptNet.py ADDED
@@ -0,0 +1,114 @@
+ import torch
+ import argparse
+ from modules.dataloader import R2DataLoader
+ from modules.tokenizers import Tokenizer
+ from modules.loss import compute_loss
+ from modules.metrics import compute_scores
+ from modules.optimizers import build_optimizer, build_lr_scheduler
+ from models.models import MedCapModel
+ from modules.trainer import Trainer
+ import numpy as np
+
+ def main():
+     parser = argparse.ArgumentParser()
+
+     # Data input settings
+     parser.add_argument('--json_path', default='data/mimic_cxr/annotation.json',
+                         help='Path to the json file')
+     parser.add_argument('--image_dir', default='data/mimic_cxr/images/',
+                         help='Directory of images')
+
+     # Dataloader settings
+     parser.add_argument('--dataset', default='mimic_cxr', help='dataset for training MedCap')
+     parser.add_argument('--bs', type=int, default=16)
+     parser.add_argument('--threshold', type=int, default=10, help='the cut-off frequency for the words.')
+     parser.add_argument('--num_workers', type=int, default=2, help='the number of workers for the dataloader.')
+     parser.add_argument('--max_seq_length', type=int, default=1024, help='the maximum sequence length of the reports.')
+
+     # Trainer settings
+     parser.add_argument('--epochs', type=int, default=30)
+     parser.add_argument('--n_gpu', type=int, default=1, help='the number of gpus to be used.')
+     parser.add_argument('--save_dir', type=str, default='results/mimic_cxr/', help='the path to save the models.')
+     parser.add_argument('--record_dir', type=str, default='./record_dir/',
+                         help='the path to save the results of experiments.')
+     parser.add_argument('--log_period', type=int, default=1000, help='the logging interval (in batches).')
+     parser.add_argument('--save_period', type=int, default=1)
+     parser.add_argument('--monitor_mode', type=str, default='max', choices=['min', 'max'], help='whether to max or min the metric.')
+     parser.add_argument('--monitor_metric', type=str, default='BLEU_4', help='the metric to be monitored.')
+     parser.add_argument('--early_stop', type=int, default=50, help='the patience of training.')
+
+     # Training related
+     parser.add_argument('--noise_inject', default='no', choices=['yes', 'no'])
+
+     # Sample related
+     parser.add_argument('--sample_method', type=str, default='greedy', help='the sample method to sample a report.')
+     parser.add_argument('--prompt', default='prompt/prompt.pth', help='the path to the prompt weights.')
+     parser.add_argument('--prompt_load', default='no', choices=['yes', 'no'])
+
+     # Optimization
+     parser.add_argument('--optim', type=str, default='Adam', help='the type of the optimizer.')
+     parser.add_argument('--lr_ve', type=float, default=1e-5, help='the learning rate for the visual extractor.')
+     parser.add_argument('--lr_ed', type=float, default=5e-4, help='the learning rate for the remaining parameters.')
+     parser.add_argument('--weight_decay', type=float, default=5e-5, help='the weight decay.')
+     parser.add_argument('--adam_betas', type=tuple, default=(0.9, 0.98), help='the betas for Adam.')
+     parser.add_argument('--adam_eps', type=float, default=1e-9, help='the epsilon for Adam.')
+     parser.add_argument('--amsgrad', type=bool, default=True, help='whether to use AMSGrad.')
+     parser.add_argument('--noamopt_warmup', type=int, default=5000, help='the warmup steps for the Noam scheduler.')
+     parser.add_argument('--noamopt_factor', type=int, default=1, help='the scale factor for the Noam scheduler.')
+
+     # Learning rate scheduler
+     parser.add_argument('--lr_scheduler', type=str, default='StepLR', help='the type of the learning rate scheduler.')
+     parser.add_argument('--step_size', type=int, default=50, help='the step size of the learning rate scheduler.')
+     parser.add_argument('--gamma', type=float, default=0.1, help='the gamma of the learning rate scheduler.')
+
+     # Others
+     parser.add_argument('--seed', type=int, default=9153, help='the random seed.')
+     parser.add_argument('--resume', type=str, help='the checkpoint path to resume training from.')
+     parser.add_argument('--train_mode', default='base', choices=['base', 'fine-tuning'],
+                         help='Training mode: base (autoencoding) or fine-tuning (fully supervised training or fine-tuning on downstream datasets)')
+     parser.add_argument('--F_version', default='v1', choices=['v1', 'v2'])
+     parser.add_argument('--clip_update', default='no', choices=['yes', 'no'])
+
+     # Fine-tuning
+     parser.add_argument('--random_init', default='yes', choices=['yes', 'no'],
+                         help='Whether to load the pre-trained weights for fine-tuning.')
+     parser.add_argument('--weight_path', default='path_to_default_weights', type=str,
+                         help='Path to the pre-trained model weights.')
+     args = parser.parse_args()
+
+     # fix random seeds
+     torch.manual_seed(args.seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+     np.random.seed(args.seed)
+
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+     # create tokenizer
+     tokenizer = Tokenizer(args)
+
+     # create data loaders
+     train_dataloader = R2DataLoader(args, tokenizer, split='train', shuffle=True)
+     val_dataloader = R2DataLoader(args, tokenizer, split='val', shuffle=False)
+     test_dataloader = R2DataLoader(args, tokenizer, split='test', shuffle=False)
+
+     # get function handles of loss and metrics
+     criterion = compute_loss
+     metrics = compute_scores
+     model = MedCapModel(args, tokenizer)
+
+     if args.train_mode == 'fine-tuning' and args.random_init == 'no':
+         # Load weights from the specified path
+         checkpoint = torch.load(args.weight_path, map_location=device)
+         model.load_state_dict(checkpoint)
+
+     # build optimizer and learning rate scheduler
+     optimizer = build_optimizer(args, model)
+     lr_scheduler = build_lr_scheduler(args, optimizer)
+
+     # build trainer and start training
+     trainer = Trainer(model, criterion, metrics, optimizer, args, lr_scheduler, train_dataloader, val_dataloader, test_dataloader)
+     trainer.train()
+
+ if __name__ == '__main__':
+     main()
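
Note: the seed block in `main()` fixes `torch` and `numpy`. As a hedged aside, fuller run-to-run determinism usually also means seeding Python's `random` module and all CUDA devices; a minimal helper sketch, not part of this commit:

```python
import random

import numpy as np
import torch

def fix_seed(seed: int) -> None:
    # Seed every RNG the training loop may touch, mirroring main() above
    # plus Python's `random` and the per-GPU CUDA generators.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
```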
app.py ADDED
@@ -0,0 +1,121 @@
+ import gradio as gr
+ import torch
+ from PIL import Image
+ from models.r2gen import R2GenModel
+ from modules.tokenizers import Tokenizer
+ import argparse
+
+ # Assuming you have a predefined configuration function for model args
+ def get_model_args():
+     parser = argparse.ArgumentParser()
+
+     # Model loader settings
+     parser.add_argument('--load', type=str, default='ckpts/few-shot.pth', help='the path to the model weights.')
+     parser.add_argument('--prompt', type=str, default='prompt/prompt.pth', help='the path to the prompt weights.')
+
+     # Data input settings
+     parser.add_argument('--image_path', type=str, default='example_figs/example_fig1.jpg', help='the path to the test image.')
+     parser.add_argument('--image_dir', type=str, default='data/images/', help='the path to the directory containing the images.')
+     parser.add_argument('--ann_path', type=str, default='data/annotation.json', help='the path to the annotation file.')
+
+     # Data loader settings
+     parser.add_argument('--dataset_name', type=str, default='mimic_cxr', help='the dataset to be used.')
+     parser.add_argument('--max_seq_length', type=int, default=60, help='the maximum sequence length of the reports.')
+     parser.add_argument('--threshold', type=int, default=3, help='the cut-off frequency for the words.')
+     parser.add_argument('--num_workers', type=int, default=2, help='the number of workers for the dataloader.')
+     parser.add_argument('--batch_size', type=int, default=16, help='the number of samples in a batch.')
+
+     # Model settings (for visual extractor)
+     parser.add_argument('--visual_extractor', type=str, default='resnet101', help='the visual extractor to be used.')
+     parser.add_argument('--visual_extractor_pretrained', type=bool, default=True, help='whether to load the pretrained visual extractor.')
+
+     # Model settings (for Transformer)
+     parser.add_argument('--d_model', type=int, default=512, help='the dimension of Transformer.')
+     parser.add_argument('--d_ff', type=int, default=512, help='the dimension of FFN.')
+     parser.add_argument('--d_vf', type=int, default=2048, help='the dimension of the patch features.')
+     parser.add_argument('--num_heads', type=int, default=8, help='the number of heads in Transformer.')
+     parser.add_argument('--num_layers', type=int, default=3, help='the number of layers of Transformer.')
+     parser.add_argument('--dropout', type=float, default=0.1, help='the dropout rate of Transformer.')
+     parser.add_argument('--logit_layers', type=int, default=1, help='the number of logit layers.')
+     parser.add_argument('--bos_idx', type=int, default=0, help='the index of <bos>.')
+     parser.add_argument('--eos_idx', type=int, default=0, help='the index of <eos>.')
+     parser.add_argument('--pad_idx', type=int, default=0, help='the index of <pad>.')
+     parser.add_argument('--use_bn', type=int, default=0, help='whether to use batch normalization.')
+     parser.add_argument('--drop_prob_lm', type=float, default=0.5, help='the dropout rate of the output layer.')
+     # for Relational Memory
+     parser.add_argument('--rm_num_slots', type=int, default=3, help='the number of memory slots.')
+     parser.add_argument('--rm_num_heads', type=int, default=8, help='the number of heads in RM.')
+     parser.add_argument('--rm_d_model', type=int, default=512, help='the dimension of RM.')
+
+     # Sample related
+     parser.add_argument('--sample_method', type=str, default='beam_search', help='the sample method to sample a report.')
+     parser.add_argument('--beam_size', type=int, default=3, help='the beam size when beam searching.')
+     parser.add_argument('--temperature', type=float, default=1.0, help='the temperature when sampling.')
+     parser.add_argument('--sample_n', type=int, default=1, help='the sample number per image.')
+     parser.add_argument('--group_size', type=int, default=1, help='the group size.')
+     parser.add_argument('--output_logsoftmax', type=int, default=1, help='whether to output the probabilities.')
+     parser.add_argument('--decoding_constraint', type=int, default=0, help='whether to use the decoding constraint.')
+     parser.add_argument('--block_trigrams', type=int, default=1, help='whether to use trigram blocking.')
+
+     # Trainer settings
+     parser.add_argument('--n_gpu', type=int, default=1, help='the number of gpus to be used.')
+     parser.add_argument('--epochs', type=int, default=100, help='the number of training epochs.')
+     parser.add_argument('--save_dir', type=str, default='results/iu_xray', help='the path to save the models.')
+     parser.add_argument('--record_dir', type=str, default='records/', help='the path to save the results of experiments.')
+     parser.add_argument('--save_period', type=int, default=1, help='the saving period.')
+     parser.add_argument('--monitor_mode', type=str, default='max', choices=['min', 'max'], help='whether to max or min the metric.')
+     parser.add_argument('--monitor_metric', type=str, default='BLEU_4', help='the metric to be monitored.')
+     parser.add_argument('--early_stop', type=int, default=50, help='the patience of training.')
+
+     # Optimization
+     parser.add_argument('--optim', type=str, default='Adam', help='the type of the optimizer.')
+     parser.add_argument('--lr_ve', type=float, default=5e-5, help='the learning rate for the visual extractor.')
+     parser.add_argument('--lr_ed', type=float, default=1e-4, help='the learning rate for the remaining parameters.')
+     parser.add_argument('--weight_decay', type=float, default=5e-5, help='the weight decay.')
+     parser.add_argument('--amsgrad', type=bool, default=True, help='whether to use AMSGrad.')
+
+     # Learning Rate Scheduler
+     parser.add_argument('--lr_scheduler', type=str, default='StepLR', help='the type of the learning rate scheduler.')
+     parser.add_argument('--step_size', type=int, default=50, help='the step size of the learning rate scheduler.')
+     parser.add_argument('--gamma', type=float, default=0.1, help='the gamma of the learning rate scheduler.')
+
+     # Others
+     parser.add_argument('--seed', type=int, default=9233, help='the random seed.')
+     parser.add_argument('--resume', type=str, help='the checkpoint path to resume training from.')
+
+     args = parser.parse_args()
+     return args
+
+ def load_model():
+     args = get_model_args()
+     tokenizer = Tokenizer(args)
+     device = 'cuda' if torch.cuda.is_available() else 'cpu'  # Determine the device dynamically
+     model = R2GenModel(args, tokenizer).to(device)
+     checkpoint_path = args.load
+     # Ensure the state dict is loaded onto the same device as the model
+     state_dict = torch.load(checkpoint_path, map_location=device)
+     model_state_dict = state_dict['state_dict'] if 'state_dict' in state_dict else state_dict
+     model.load_state_dict(model_state_dict)
+     model.eval()
+     return model, tokenizer
+
+ model, tokenizer = load_model()
+
+ def generate_report(image):
+     image = Image.fromarray(image).convert('RGB')
+     with torch.no_grad():
+         output = model([image], mode='sample')
+     reports = tokenizer.decode_batch(output.cpu().numpy())
+     return reports[0]
+
+ # Define Gradio interface
+ iface = gr.Interface(
+     fn=generate_report,
+     inputs=gr.Image(),  # gr.inputs.Image() is deprecated; gr.Image() passes a numpy array by default
+     outputs="text",
+     title="PromptNet",
+     description="Upload a medical image for thorax disease reporting."
+ )
+
+ if __name__ == "__main__":
+     iface.launch()
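
Note: `load_model()` runs at import time, so the report generator can also be exercised headlessly. A minimal sketch, not part of this commit, assuming `app.py` imports cleanly and the LFS checkpoint has been pulled (the image path is one of the example figures in this commit):

```python
# Hypothetical headless use of app.py's generate_report.
import numpy as np
from PIL import Image

from app import generate_report  # importing app also loads the checkpoint

img = np.asarray(Image.open("example_figs/example_fig1.jpg.png").convert("RGB"))
print(generate_report(img))  # the decoded report for this image
```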
ckpts/few-shot.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fa4c3ef1a822fdca8895f6ad0c73b4f355b036d0d28a8523aaf51f58c7393f38
+ size 1660341639
data/annotation.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5d9590de8db89b0c74343a7e2aecba61e8029e15801de10ec4e030be80b62adc
+ size 155745921
decoder_config/decoder_config.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c454e6bddb15af52c82734f1796391bf3a10a6c5533ea095de06f661ebb858bb
+ size 1744
example_figs/example_fig1.jpg.png ADDED
example_figs/example_fig2.jpg.jpg ADDED
example_figs/example_fig3.jpg.png ADDED
inference.py ADDED
@@ -0,0 +1,110 @@
+ import torch
+ from models.r2gen import R2GenModel
+ from PIL import Image
+ from modules.tokenizers import Tokenizer
+ import main
+ import argparse
+ import json
+ import re
+ from collections import Counter
+
+ def parse_args():
+     parser = argparse.ArgumentParser()
+
+     # Model loader settings
+     parser.add_argument('--load', type=str, default='ckpt/checkpoint.pth', help='the path to the model weights.')
+     parser.add_argument('--prompt', type=str, default='ckpt/prompt.pth', help='the path to the prompt weights.')
+
+     # Data input settings
+     parser.add_argument('--image_path', type=str, default='example_figs/fig1.jpg', help='the path to the test image.')
+     parser.add_argument('--image_dir', type=str, default='data/images/', help='the path to the directory containing the images.')
+     parser.add_argument('--ann_path', type=str, default='data/annotation.json', help='the path to the annotation file.')
+
+     # Data loader settings
+     parser.add_argument('--dataset_name', type=str, default='mimic_cxr', help='the dataset to be used.')
+     parser.add_argument('--max_seq_length', type=int, default=60, help='the maximum sequence length of the reports.')
+     parser.add_argument('--threshold', type=int, default=3, help='the cut-off frequency for the words.')
+     parser.add_argument('--num_workers', type=int, default=2, help='the number of workers for the dataloader.')
+     parser.add_argument('--batch_size', type=int, default=16, help='the number of samples in a batch.')
+
+     # Model settings (for visual extractor)
+     parser.add_argument('--visual_extractor', type=str, default='resnet101', help='the visual extractor to be used.')
+     parser.add_argument('--visual_extractor_pretrained', type=bool, default=True, help='whether to load the pretrained visual extractor.')
+
+     # Model settings (for Transformer)
+     parser.add_argument('--d_model', type=int, default=512, help='the dimension of Transformer.')
+     parser.add_argument('--d_ff', type=int, default=512, help='the dimension of FFN.')
+     parser.add_argument('--d_vf', type=int, default=2048, help='the dimension of the patch features.')
+     parser.add_argument('--num_heads', type=int, default=8, help='the number of heads in Transformer.')
+     parser.add_argument('--num_layers', type=int, default=3, help='the number of layers of Transformer.')
+     parser.add_argument('--dropout', type=float, default=0.1, help='the dropout rate of Transformer.')
+     parser.add_argument('--logit_layers', type=int, default=1, help='the number of logit layers.')
+     parser.add_argument('--bos_idx', type=int, default=0, help='the index of <bos>.')
+     parser.add_argument('--eos_idx', type=int, default=0, help='the index of <eos>.')
+     parser.add_argument('--pad_idx', type=int, default=0, help='the index of <pad>.')
+     parser.add_argument('--use_bn', type=int, default=0, help='whether to use batch normalization.')
+     parser.add_argument('--drop_prob_lm', type=float, default=0.5, help='the dropout rate of the output layer.')
+     # for Relational Memory
+     parser.add_argument('--rm_num_slots', type=int, default=3, help='the number of memory slots.')
+     parser.add_argument('--rm_num_heads', type=int, default=8, help='the number of heads in RM.')
+     parser.add_argument('--rm_d_model', type=int, default=512, help='the dimension of RM.')
+
+     # Sample related
+     parser.add_argument('--sample_method', type=str, default='beam_search', help='the sample method to sample a report.')
+     parser.add_argument('--beam_size', type=int, default=3, help='the beam size when beam searching.')
+     parser.add_argument('--temperature', type=float, default=1.0, help='the temperature when sampling.')
+     parser.add_argument('--sample_n', type=int, default=1, help='the sample number per image.')
+     parser.add_argument('--group_size', type=int, default=1, help='the group size.')
+     parser.add_argument('--output_logsoftmax', type=int, default=1, help='whether to output the probabilities.')
+     parser.add_argument('--decoding_constraint', type=int, default=0, help='whether to use the decoding constraint.')
+     parser.add_argument('--block_trigrams', type=int, default=1, help='whether to use trigram blocking.')
+
+     # Trainer settings
+     parser.add_argument('--n_gpu', type=int, default=1, help='the number of gpus to be used.')
+     parser.add_argument('--epochs', type=int, default=100, help='the number of training epochs.')
+     parser.add_argument('--save_dir', type=str, default='results/iu_xray', help='the path to save the models.')
+     parser.add_argument('--record_dir', type=str, default='records/', help='the path to save the results of experiments.')
+     parser.add_argument('--save_period', type=int, default=1, help='the saving period.')
+     parser.add_argument('--monitor_mode', type=str, default='max', choices=['min', 'max'], help='whether to max or min the metric.')
+     parser.add_argument('--monitor_metric', type=str, default='BLEU_4', help='the metric to be monitored.')
+     parser.add_argument('--early_stop', type=int, default=50, help='the patience of training.')
+
+     # Optimization
+     parser.add_argument('--optim', type=str, default='Adam', help='the type of the optimizer.')
+     parser.add_argument('--lr_ve', type=float, default=5e-5, help='the learning rate for the visual extractor.')
+     parser.add_argument('--lr_ed', type=float, default=1e-4, help='the learning rate for the remaining parameters.')
+     parser.add_argument('--weight_decay', type=float, default=5e-5, help='the weight decay.')
+     parser.add_argument('--amsgrad', type=bool, default=True, help='whether to use AMSGrad.')
+
+     # Learning Rate Scheduler
+     parser.add_argument('--lr_scheduler', type=str, default='StepLR', help='the type of the learning rate scheduler.')
+     parser.add_argument('--step_size', type=int, default=50, help='the step size of the learning rate scheduler.')
+     parser.add_argument('--gamma', type=float, default=0.1, help='the gamma of the learning rate scheduler.')
+
+     # Others
+     parser.add_argument('--seed', type=int, default=9233, help='the random seed.')
+     parser.add_argument('--resume', type=str, help='the checkpoint path to resume training from.')
+
+     args = parser.parse_args()
+     return args
+
+
+ args = parse_args()
+ tokenizer = Tokenizer(args)
+ image_path = args.image_path
+ checkpoint_path = args.load
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ image = [Image.open(image_path).convert('RGB')]
+ model = R2GenModel(args, tokenizer).to(device)
+
+ state_dict = torch.load(checkpoint_path, map_location=device)
+ model_state_dict = state_dict['state_dict']
+ # load_state_dict returns a report of missing/unexpected keys, not the model,
+ # so it cannot be chained with .to(); the model is moved to the device above.
+ model.load_state_dict(model_state_dict)
+
+ model.eval()
+ with torch.no_grad():
+     output = model(image, mode='sample')
+     reports = model.tokenizer.decode_batch(output.cpu().numpy())
+     print(reports)
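
Note: everything below `parse_args()` runs at import, so this file is meant to be executed as a script. A hedged invocation sketch, not part of this commit (the paths point at files uploaded here; the script's own `--load` default, `ckpt/checkpoint.pth`, is not among the 50 files shown):

```python
# Hypothetical invocation of inference.py as a subprocess.
import subprocess

subprocess.run(
    ["python", "inference.py",
     "--load", "ckpts/few-shot.pth",
     "--image_path", "example_figs/example_fig1.jpg.png"],
    check=True,
)
```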
models/models.py ADDED
@@ -0,0 +1,125 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import pickle
+ from typing import Tuple
+ from transformers import GPT2LMHeadModel
+ from modules.decoder import DeCap
+ from medclip import MedCLIPModel, MedCLIPVisionModelViT
+ import math
+ import pdb
+
+
+ class MedCapModel(nn.Module):
+     def __init__(self, args, tokenizer):
+         super(MedCapModel, self).__init__()
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         self.args = args
+         self.tokenizer = tokenizer
+         self.model = DeCap(args, tokenizer)
+
+         self.align_model = MedCLIPModel(vision_cls=MedCLIPVisionModelViT)
+         self.align_model.from_pretrained()
+         self.prompt = torch.load(args.prompt)
+         if args.dataset == 'iu_xray':
+             self.forward = self.forward_iu_xray
+         else:
+             self.forward = self.forward_mimic_cxr
+
+     def noise_injection(self, x, variance=0.001, modality_offset=None, dont_norm=False):
+         if variance == 0.0:
+             return x
+         std = math.sqrt(variance)
+         if not dont_norm:
+             x = torch.nn.functional.normalize(x, dim=1)
+         # add Gaussian noise to the (optionally normalized) features; by some
+         # conventions, multivariate noise should be divided by sqrt of the dim
+         x = x + torch.randn(x.shape, device=x.device) * std
+         if modality_offset is not None:
+             x = x + modality_offset
+         return torch.nn.functional.normalize(x, dim=1)
+
+     def align_encode_images_iu_xray(self, images):
+         # Split the images
+         image1, image2 = images.unbind(dim=1)
+         # Encode each image
+         feature1 = self.align_model.encode_image(image1)
+         feature2 = self.align_model.encode_image(image2)
+         if self.args.prompt_load == 'yes':
+             sim_1 = feature1 @ self.prompt.T.float()
+             sim_1 = (sim_1 * 100).softmax(dim=-1)
+             prefix_embedding_1 = sim_1 @ self.prompt.float()
+             prefix_embedding_1 /= prefix_embedding_1.norm(dim=-1, keepdim=True)
+
+             sim_2 = feature2 @ self.prompt.T.float()
+             sim_2 = (sim_2 * 100).softmax(dim=-1)
+             prefix_embedding_2 = sim_2 @ self.prompt.float()
+             prefix_embedding_2 /= prefix_embedding_2.norm(dim=-1, keepdim=True)
+             averaged_prompt_features = torch.mean(torch.stack([prefix_embedding_1, prefix_embedding_2]), dim=0)
+             return averaged_prompt_features
+         else:
+             # Average the features of the two views
+             averaged_features = torch.mean(torch.stack([feature1, feature2]), dim=0)
+             return averaged_features
+
+     def align_encode_images_mimic_cxr(self, images):
+         feature = self.align_model.encode_image(images)
+         if self.args.prompt_load == 'yes':
+             sim = feature @ self.prompt.T.float()
+             sim = (sim * 100).softmax(dim=-1)
+             prefix_embedding = sim @ self.prompt.float()
+             prefix_embedding /= prefix_embedding.norm(dim=-1, keepdim=True)
+             return prefix_embedding
+         else:
+             return feature
+
+     def forward_iu_xray(self, reports_ids, align_ids, align_masks, images, mode='train', update_opts={}):
+         self.align_model.to(self.device)
+         self.align_model.eval()
+         align_ids = align_ids.long()
+
+         align_image_feature = None
+         if self.args.train_mode == 'fine-tuning':
+             align_image_feature = self.align_encode_images_iu_xray(images)
+         if mode == 'train':
+             align_text_feature = self.align_model.encode_text(align_ids, align_masks)
+             if self.args.noise_inject == 'yes':
+                 align_text_feature = self.noise_injection(align_text_feature)
+
+             if self.args.train_mode == 'fine-tuning':
+                 if self.args.F_version == 'v1':
+                     # note: self.fc_reduce_dim (a layer mapping the concatenated
+                     # features back to the text-feature width) is assumed to be
+                     # defined elsewhere; it is not created in __init__ above
+                     combined_feature = torch.cat([align_text_feature, align_image_feature], dim=-1)
+                     align_text_feature = self.fc_reduce_dim(combined_feature)
+                 if self.args.F_version == 'v2':
+                     align_text_feature = align_image_feature
+
+             outputs = self.model(align_text_feature, reports_ids, mode='forward')
+             logits = outputs.logits
+             logits = logits[:, :-1]
+             return logits
+         elif mode == 'sample':
+             align_image_feature = self.align_encode_images_iu_xray(images)
+             outputs = self.model(align_image_feature, reports_ids, mode='sample', update_opts=update_opts)
+             return outputs
+         else:
+             raise ValueError
+
+     def forward_mimic_cxr(self, reports_ids, align_ids, align_masks, images, mode='train', update_opts={}):
+         self.align_model.to(self.device)
+         self.align_model.eval()
+         align_ids = align_ids.long()
+         if mode == 'train':
+             align_text_feature = self.align_model.encode_text(align_ids, align_masks)
+             if self.args.noise_inject == 'yes':
+                 align_text_feature = self.noise_injection(align_text_feature)
+             outputs = self.model(align_text_feature, reports_ids, mode='forward')
+             logits = outputs.logits
+             logits = logits[:, :-1]
+             return logits
+         elif mode == 'sample':
+             align_image_feature = self.align_encode_images_mimic_cxr(images)
+             outputs = self.model(align_image_feature, reports_ids, mode='sample', update_opts=update_opts)
+             return outputs
+         else:
+             raise ValueError
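
Note: the two `align_encode_images_*` methods share one mechanism: an image feature is re-expressed as a temperature-scaled, softmax-weighted average over the stored prompt embeddings, projecting it into the support of the text features the decoder saw during training. A self-contained toy illustration (shapes are illustrative, not the model's actual sizes):

```python
import torch
import torch.nn.functional as F

prompt = F.normalize(torch.randn(1000, 512), dim=-1)   # stored prompt bank
feature = F.normalize(torch.randn(4, 512), dim=-1)     # batch of image features

sim = (feature @ prompt.T * 100).softmax(dim=-1)       # temperature-scaled similarities
prefix = sim @ prompt                                  # weighted average of prompt vectors
prefix = prefix / prefix.norm(dim=-1, keepdim=True)    # re-normalize, as in the model
print(prefix.shape)                                    # torch.Size([4, 512])
```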
models/r2gen.py ADDED
@@ -0,0 +1,63 @@
+ import torch
+ import torch.nn as nn
+ import numpy as np
+
+ from modules.visual_extractor import VisualExtractor
+ from modules.encoder_decoder import EncoderDecoder
+ import torch.nn.functional as F
+
+ class R2GenModel(nn.Module):
+     def __init__(self, args, tokenizer):
+         super(R2GenModel, self).__init__()
+         self.args = args
+         self.tokenizer = tokenizer
+         self.visual_extractor = VisualExtractor(args)
+         self.encoder_decoder = EncoderDecoder(args, tokenizer)
+         if args.dataset_name == 'iu_xray':
+             self.forward = self.forward_iu_xray
+         else:
+             self.forward = self.forward_mimic_cxr
+         self.affine_a = nn.Linear(1024, 2048)
+         self.affine_b = nn.Linear(1024, 2048)
+         self.affine_c = nn.Linear(1024, 2048)
+         self.affine_d = nn.Linear(1024, 2048)
+         self.affine_aa = nn.Linear(1024, 2048)
+         self.affine_bb = nn.Linear(1024, 2048)
+
+     def __str__(self):
+         model_parameters = filter(lambda p: p.requires_grad, self.parameters())
+         params = sum([np.prod(p.size()) for p in model_parameters])
+         return super().__str__() + '\nTrainable parameters: {}'.format(params)
+
+     def forward_iu_xray(self, images, targets=None, mode='train'):
+         att_feats_0, fc_feats_0 = self.visual_extractor(images[:, 0])
+         att_feats_1, fc_feats_1 = self.visual_extractor(images[:, 1])
+         # new: project both views' features to the width the decoder expects
+         att_feats_0 = F.relu(self.affine_a(att_feats_0))
+         fc_feats_0 = F.relu(self.affine_b(fc_feats_0))
+         att_feats_1 = F.relu(self.affine_c(att_feats_1))
+         fc_feats_1 = F.relu(self.affine_d(fc_feats_1))
+
+         fc_feats = torch.cat((fc_feats_0, fc_feats_1), dim=1)
+         att_feats = torch.cat((att_feats_0, att_feats_1), dim=1)
+         if mode == 'train':
+             output = self.encoder_decoder(fc_feats, att_feats, targets, mode='forward')
+         elif mode == 'sample':
+             output, _ = self.encoder_decoder(fc_feats, att_feats, mode='sample')
+         else:
+             raise ValueError
+         return output
+
+     def forward_mimic_cxr(self, images, targets=None, mode='train'):
+         att_feats1, fc_feats1 = self.visual_extractor(images)
+         att_feats = F.relu(self.affine_aa(att_feats1))
+         fc_feats = F.relu(self.affine_bb(fc_feats1))
+
+         if mode == 'train':
+             output = self.encoder_decoder(fc_feats, att_feats, targets, mode='forward')
+         elif mode == 'sample':
+             output, _ = self.encoder_decoder(fc_feats, att_feats, mode='sample')
+         else:
+             raise ValueError
+         return output
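
Note: for orientation, `forward_iu_xray` encodes the two views separately, projects each 1024-dim feature to the 2048-dim `d_vf` width the decoder expects, and concatenates along the patch axis. A toy shape check (batch and patch counts are illustrative; only the 1024→2048 projection mirrors the code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

affine = nn.Linear(1024, 2048)        # stands in for affine_a..affine_d above
view0 = torch.randn(2, 49, 1024)      # (batch, patches, dim)
view1 = torch.randn(2, 49, 1024)

att_feats = torch.cat([F.relu(affine(view0)), F.relu(affine(view1))], dim=1)
print(att_feats.shape)                # torch.Size([2, 98, 2048])
```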
modules/att_model.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
9
+
10
+ import modules.utils as utils
11
+ from modules.caption_model import CaptionModel
12
+
13
+
14
+ def sort_pack_padded_sequence(input, lengths):
15
+ sorted_lengths, indices = torch.sort(lengths, descending=True)
16
+ tmp = pack_padded_sequence(input[indices], sorted_lengths, batch_first=True)
17
+ inv_ix = indices.clone()
18
+ inv_ix[indices] = torch.arange(0, len(indices)).type_as(inv_ix)
19
+ return tmp, inv_ix
20
+
21
+
22
+ def pad_unsort_packed_sequence(input, inv_ix):
23
+ tmp, _ = pad_packed_sequence(input, batch_first=True)
24
+ tmp = tmp[inv_ix]
25
+ return tmp
26
+
27
+
28
+ def pack_wrapper(module, att_feats, att_masks):
29
+ if att_masks is not None:
30
+ packed, inv_ix = sort_pack_padded_sequence(att_feats, att_masks.data.long().sum(1))
31
+ return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix)
32
+ else:
33
+ return module(att_feats)
34
+
35
+
36
+ class AttModel(CaptionModel):
37
+ def __init__(self, args, tokenizer):
38
+ super(AttModel, self).__init__()
39
+ self.args = args
40
+ self.tokenizer = tokenizer
41
+ self.vocab_size = len(tokenizer.idx2token)
42
+ self.input_encoding_size = args.d_model
43
+ self.rnn_size = args.d_ff
44
+ self.num_layers = args.num_layers
45
+ self.drop_prob_lm = args.drop_prob_lm
46
+ self.max_seq_length = args.max_seq_length
47
+ self.att_feat_size = args.d_vf
48
+ self.att_hid_size = args.d_model
49
+
50
+ self.bos_idx = args.bos_idx
51
+ self.eos_idx = args.eos_idx
52
+ self.pad_idx = args.pad_idx
53
+
54
+ self.use_bn = args.use_bn
55
+
56
+ self.embed = lambda x: x
57
+ self.fc_embed = lambda x: x
58
+ self.att_embed = nn.Sequential(*(
59
+ ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ()) +
60
+ (nn.Linear(self.att_feat_size, self.input_encoding_size),
61
+ nn.ReLU(),
62
+ nn.Dropout(self.drop_prob_lm)) +
63
+ ((nn.BatchNorm1d(self.input_encoding_size),) if self.use_bn == 2 else ())))
64
+
65
+ def clip_att(self, att_feats, att_masks):
66
+ # Clip the length of att_masks and att_feats to the maximum length
67
+ if att_masks is not None:
68
+ max_len = att_masks.data.long().sum(1).max()
69
+ att_feats = att_feats[:, :max_len].contiguous()
70
+ att_masks = att_masks[:, :max_len].contiguous()
71
+ return att_feats, att_masks
72
+
73
+ def _prepare_feature(self, fc_feats, att_feats, att_masks):
74
+ att_feats, att_masks = self.clip_att(att_feats, att_masks)
75
+
76
+ # embed fc and att feats
77
+ fc_feats = self.fc_embed(fc_feats)
78
+ att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
79
+
80
+ # Project the attention feats first to reduce memory and computation comsumptions.
81
+ p_att_feats = self.ctx2att(att_feats)
82
+
83
+ return fc_feats, att_feats, p_att_feats, att_masks
84
+
85
+ def get_logprobs_state(self, it, fc_feats, att_feats, p_att_feats, att_masks, state, output_logsoftmax=1):
86
+ # 'it' contains a word index
87
+ xt = self.embed(it)
88
+
89
+ output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks)
90
+ if output_logsoftmax:
91
+ logprobs = F.log_softmax(self.logit(output), dim=1)
92
+ else:
93
+ logprobs = self.logit(output)
94
+
95
+ return logprobs, state
96
+
97
+ def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
98
+ beam_size = opt.get('beam_size', 10)
99
+ group_size = opt.get('group_size', 1)
100
+ sample_n = opt.get('sample_n', 10)
101
+ # when sample_n == beam_size then each beam is a sample.
102
+ assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'
103
+ batch_size = fc_feats.size(0)
104
+
105
+ p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
106
+
107
+ assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
108
+ seq = fc_feats.new_full((batch_size * sample_n, self.max_seq_length), self.pad_idx, dtype=torch.long)
109
+ seqLogprobs = fc_feats.new_zeros(batch_size * sample_n, self.max_seq_length, self.vocab_size + 1)
110
+ # lets process every image independently for now, for simplicity
111
+
112
+ self.done_beams = [[] for _ in range(batch_size)]
113
+
114
+ state = self.init_hidden(batch_size)
115
+
116
+ # first step, feed bos
117
+ it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long)
118
+ logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)
119
+
120
+ p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(beam_size,
121
+ [p_fc_feats, p_att_feats,
122
+ pp_att_feats, p_att_masks]
123
+ )
124
+ self.done_beams = self.beam_search(state, logprobs, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, opt=opt)
125
+ for k in range(batch_size):
126
+ if sample_n == beam_size:
127
+ for _n in range(sample_n):
128
+ seq_len = self.done_beams[k][_n]['seq'].shape[0]
129
+ seq[k * sample_n + _n, :seq_len] = self.done_beams[k][_n]['seq']
130
+ seqLogprobs[k * sample_n + _n, :seq_len] = self.done_beams[k][_n]['logps']
131
+ else:
132
+ seq_len = self.done_beams[k][0]['seq'].shape[0]
133
+ seq[k, :seq_len] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
134
+ seqLogprobs[k, :seq_len] = self.done_beams[k][0]['logps']
135
+ # return the samples and their log likelihoods
136
+ return seq, seqLogprobs
137
+
138
+ def _sample(self, fc_feats, att_feats, att_masks=None):
139
+ opt = self.args.__dict__
140
+ sample_method = opt.get('sample_method', 'greedy')
141
+ beam_size = opt.get('beam_size', 1)
142
+ temperature = opt.get('temperature', 1.0)
143
+ sample_n = int(opt.get('sample_n', 1))
144
+ group_size = opt.get('group_size', 1)
145
+ output_logsoftmax = opt.get('output_logsoftmax', 1)
146
+ decoding_constraint = opt.get('decoding_constraint', 0)
147
+ block_trigrams = opt.get('block_trigrams', 0)
148
+ if beam_size > 1 and sample_method in ['greedy', 'beam_search']:
149
+ return self._sample_beam(fc_feats, att_feats, att_masks, opt)
150
+ if group_size > 1:
151
+ return self._diverse_sample(fc_feats, att_feats, att_masks, opt)
152
+
153
+ batch_size = fc_feats.size(0)
154
+ state = self.init_hidden(batch_size * sample_n)
155
+
156
+ p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
157
+
158
+ if sample_n > 1:
159
+ p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(sample_n,
160
+ [p_fc_feats, p_att_feats,
161
+ pp_att_feats, p_att_masks]
162
+ )
163
+
164
+ trigrams = [] # will be a list of batch_size dictionaries
165
+
166
+ seq = fc_feats.new_full((batch_size * sample_n, self.max_seq_length), self.pad_idx, dtype=torch.long)
167
+ seqLogprobs = fc_feats.new_zeros(batch_size * sample_n, self.max_seq_length, self.vocab_size + 1)
168
+ for t in range(self.max_seq_length + 1):
169
+ if t == 0: # input <bos>
170
+ it = fc_feats.new_full([batch_size * sample_n], self.bos_idx, dtype=torch.long)
171
+
172
+ logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state,
173
+ output_logsoftmax=output_logsoftmax)
174
+
175
+ if decoding_constraint and t > 0:
176
+ tmp = logprobs.new_zeros(logprobs.size())
177
+ tmp.scatter_(1, seq[:, t - 1].data.unsqueeze(1), float('-inf'))
178
+ logprobs = logprobs + tmp
179
+
180
+ # Mess with trigrams
181
+ # Copy from https://github.com/lukemelas/image-paragraph-captioning
182
+ if block_trigrams and t >= 3:
183
+ # Store trigram generated at last step
184
+ prev_two_batch = seq[:, t - 3:t - 1]
185
+ for i in range(batch_size): # = seq.size(0)
186
+ prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
187
+ current = seq[i][t - 1]
188
+ if t == 3: # initialize
189
+ trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int}
190
+ elif t > 3:
191
+ if prev_two in trigrams[i]: # add to list
192
+ trigrams[i][prev_two].append(current)
193
+ else: # create list
194
+ trigrams[i][prev_two] = [current]
195
+ # Block used trigrams at next step
196
+ prev_two_batch = seq[:, t - 2:t]
197
+ mask = torch.zeros(logprobs.size(), requires_grad=False).cuda() # batch_size x vocab_size
198
+ for i in range(batch_size):
199
+ prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
200
+ if prev_two in trigrams[i]:
201
+ for j in trigrams[i][prev_two]:
202
+ mask[i, j] += 1
203
+ # Apply mask to log probs
204
+ # logprobs = logprobs - (mask * 1e9)
205
+ alpha = 2.0 # = 4
206
+ logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best)
207
+
208
+ # sample the next word
209
+ if t == self.max_seq_length: # skip if we achieve maximum length
210
+ break
211
+ it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, temperature)
212
+
213
+ # stop when all finished
214
+ if t == 0:
215
+ unfinished = it != self.eos_idx
216
+ else:
217
+ it[~unfinished] = self.pad_idx # This allows eos_idx not being overwritten to 0
218
+ logprobs = logprobs * unfinished.unsqueeze(1).float()
219
+ unfinished = unfinished * (it != self.eos_idx)
220
+ seq[:, t] = it
221
+ seqLogprobs[:, t] = logprobs
222
+ # quit loop if all sequences have finished
223
+ if unfinished.sum() == 0:
224
+ break
225
+
226
+ return seq, seqLogprobs
227
+
228
+ def _diverse_sample(self, fc_feats, att_feats, att_masks=None, opt={}):
229
+
230
+ sample_method = opt.get('sample_method', 'greedy')
231
+ beam_size = opt.get('beam_size', 1)
232
+ temperature = opt.get('temperature', 1.0)
233
+ group_size = opt.get('group_size', 1)
234
+ diversity_lambda = opt.get('diversity_lambda', 0.5)
235
+ decoding_constraint = opt.get('decoding_constraint', 0)
236
+ block_trigrams = opt.get('block_trigrams', 0)
237
+
238
+ batch_size = fc_feats.size(0)
239
+ state = self.init_hidden(batch_size)
240
+
241
+ p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
242
+
243
+ trigrams_table = [[] for _ in range(group_size)] # will be a list of batch_size dictionaries
244
+
245
+ seq_table = [fc_feats.new_full((batch_size, self.max_seq_length), self.pad_idx, dtype=torch.long) for _ in
246
+ range(group_size)]
247
+ seqLogprobs_table = [fc_feats.new_zeros(batch_size, self.max_seq_length) for _ in range(group_size)]
248
+ state_table = [self.init_hidden(batch_size) for _ in range(group_size)]
249
+
250
+ for tt in range(self.max_seq_length + group_size):
251
+ for divm in range(group_size):
252
+ t = tt - divm
253
+ seq = seq_table[divm]
254
+ seqLogprobs = seqLogprobs_table[divm]
255
+ trigrams = trigrams_table[divm]
256
+ if t >= 0 and t <= self.max_seq_length - 1:
257
+ if t == 0: # input <bos>
258
+ it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long)
259
+ else:
260
+ it = seq[:, t - 1] # changed
261
+
262
+ logprobs, state_table[divm] = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats,
263
+ p_att_masks, state_table[divm]) # changed
264
+ logprobs = F.log_softmax(logprobs / temperature, dim=-1)
265
+
266
+ # Add diversity
267
+ if divm > 0:
268
+ unaug_logprobs = logprobs.clone()
269
+ for prev_choice in range(divm):
270
+ prev_decisions = seq_table[prev_choice][:, t]
271
+ logprobs[:, prev_decisions] = logprobs[:, prev_decisions] - diversity_lambda
272
+
273
+ if decoding_constraint and t > 0:
274
+ tmp = logprobs.new_zeros(logprobs.size())
275
+ tmp.scatter_(1, seq[:, t - 1].data.unsqueeze(1), float('-inf'))
276
+ logprobs = logprobs + tmp
277
+
278
+ # Mess with trigrams
279
+ if block_trigrams and t >= 3:
280
+ # Store trigram generated at last step
281
+ prev_two_batch = seq[:, t - 3:t - 1]
282
+ for i in range(batch_size): # = seq.size(0)
283
+ prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
284
+ current = seq[i][t - 1]
285
+ if t == 3: # initialize
286
+ trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int}
287
+ elif t > 3:
288
+ if prev_two in trigrams[i]: # add to list
289
+ trigrams[i][prev_two].append(current)
290
+ else: # create list
291
+ trigrams[i][prev_two] = [current]
292
+ # Block used trigrams at next step
293
+ prev_two_batch = seq[:, t - 2:t]
294
+ mask = torch.zeros(logprobs.size(), requires_grad=False).cuda() # batch_size x vocab_size
295
+ for i in range(batch_size):
296
+ prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
297
+ if prev_two in trigrams[i]:
298
+ for j in trigrams[i][prev_two]:
299
+ mask[i, j] += 1
300
+ # Apply mask to log probs
301
+ # logprobs = logprobs - (mask * 1e9)
302
+ alpha = 2.0 # = 4
303
+ logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best)
304
+
305
+ it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, 1)
306
+
307
+ # stop when all finished
308
+ if t == 0:
309
+ unfinished = it != self.eos_idx
310
+ else:
311
+ unfinished = (seq[:, t - 1] != self.pad_idx) & (seq[:, t - 1] != self.eos_idx)
312
+ it[~unfinished] = self.pad_idx
313
+ unfinished = unfinished & (it != self.eos_idx) # changed
314
+ seq[:, t] = it
315
+ seqLogprobs[:, t] = sampleLogprobs.view(-1)
316
+
317
+ return torch.stack(seq_table, 1).reshape(batch_size * group_size, -1), torch.stack(seqLogprobs_table,
318
+ 1).reshape(
319
+ batch_size * group_size, -1)
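Note: both `_sample` and `_diverse_sample` above apply a soft trigram-blocking penalty during decoding. A minimal standalone sketch of the same idea (the function name and toy shapes are illustrative; `alpha` and the `ln(1/2)` penalty mirror the hard-coded values above):

```python
import math
import torch

def block_repeated_trigrams(seq, logprobs, t, alpha=2.0):
    """Soft-penalize tokens that would repeat a trigram already in seq[:, :t].

    seq: (batch, max_len) token ids generated so far; logprobs: (batch, vocab).
    """
    penalized = logprobs.clone()
    for i in range(seq.size(0)):
        trigrams = {}
        for k in range(t - 2):  # collect every trigram emitted so far
            key = (seq[i, k].item(), seq[i, k + 1].item())
            trigrams.setdefault(key, set()).add(seq[i, k + 2].item())
        if t >= 2:
            prev_two = (seq[i, t - 2].item(), seq[i, t - 1].item())
            for tok in trigrams.get(prev_two, ()):
                penalized[i, tok] += alpha * math.log(0.5)  # ln(1/2) * alpha
    return penalized

out = block_repeated_trigrams(torch.randint(0, 10, (2, 8)), torch.randn(2, 10), t=5)
```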
modules/att_models.py ADDED
@@ -0,0 +1,120 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import pdb
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
11
+
12
+ import modules.utils as utils
13
+ from modules.caption_model import CaptionModel
14
+
15
+
16
+ class AttModel(CaptionModel):
17
+ def __init__(self, args, tokenizer):
18
+ super(AttModel, self).__init__()
19
+ self.args = args
20
+ self.tokenizer = tokenizer
21
+ self.vocab_size = len(tokenizer.idx2token)
22
+ self.max_seq_length = 60
23
+
24
+ def _sample(self, clip_features, gpt_tokens, update_opts={}):
25
+
26
+ opt = self.args.__dict__
27
+ opt.update(**update_opts)
28
+ sample_method = opt.get('sample_method', 'greedy')
29
+
30
+
31
+ if sample_method == 'greedy':
32
+ return self._greedy_sample(clip_features, gpt_tokens)
33
+ elif sample_method == 'beam_search':
34
+ return self._beam_search_sample(clip_features, gpt_tokens)
35
+ else:
36
+ raise ValueError("Unknown sample_method: " + sample_method)
37
+
38
+ def _greedy_sample(self, clip_features, gpt_tokens, temperature=1.0):
39
+ #input_ids = torch.full((clip_features.size(0), 1), self.tokenizer.bos_token_id).type_as(clip_features).long()
40
+ clip_features = self.clip_project(clip_features).reshape(clip_features.size(0), 1, -1)
41
+ tokens = [None for _ in range(clip_features.size(0))]
42
+ finished = [False for _ in range(clip_features.size(0))]
43
+ max_length = 200
44
+ for _ in range(max_length):
45
+ outputs = self.decoder(inputs_embeds=clip_features)
46
+ logits = outputs.logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
47
+ next_tokens = torch.argmax(logits, -1).unsqueeze(1)
48
+ next_token_embeds = self.decoder.transformer.wte(next_tokens)
49
+ for j in range(clip_features.size(0)):
50
+ if finished[j]:
51
+ continue
52
+ if tokens[j] is None:
53
+ tokens[j] = next_tokens[j]
54
+ else:
55
+ tokens[j] = torch.cat((tokens[j], next_tokens[j]), dim=0)
56
+ if next_tokens[j].item() == self.tokenizer.eos_token_id:
57
+ finished[j] = True
58
+ clip_features = torch.cat((clip_features, next_token_embeds), dim=1)
59
+ outputs = []
60
+ for token in tokens:
61
+ try:
62
+ output_list = token.squeeze().cpu().numpy().tolist()
63
+ # Pad or truncate output_list to max_length
64
+ output_list = (output_list + [self.tokenizer.pad_token_id] * max_length)[:max_length]
65
+ except Exception as e:
66
+ print(f"Error during decoding: {type(e).__name__}: {e}")
67
+ output_list = [self.tokenizer.pad_token_id] * max_length
68
+ outputs.append(output_list)
69
+
70
+ # Convert list of lists to tensor
71
+ outputs = torch.tensor(outputs, device=clip_features.device)
72
+ return outputs
73
+
74
+
75
+ def _beam_search_sample(self, clip_features, gpt_tokens, beam_size=5):
76
+ batch_size = clip_features.size(0)
77
+ # Prepare the first input for every beam
78
+ input_ids = torch.full((batch_size*beam_size, 1), self.tokenizer.bos_token_id).type_as(clip_features).long()
79
+ beam_scores = torch.zeros((batch_size, beam_size)).type_as(clip_features)
80
+ done = [False]*batch_size
81
+
82
+ for _ in range(self.max_seq_length):
83
+ outputs = self._forward(clip_features.repeat_interleave(beam_size, 0), input_ids)
84
+ next_token_logits = outputs.logits[:, -1, :]
85
+ next_token_probs = F.softmax(next_token_logits, dim=-1)
86
+
87
+ # Apply a mask for already finished beams
88
+ next_token_probs[done] = 0
89
+ next_token_probs[:, self.tokenizer.eos_token_id] = -float('Inf')
90
+
91
+ # Multiply old scores with new probabilities
92
+ scores = beam_scores.unsqueeze(2) * next_token_probs
93
+ scores = scores.view(batch_size, -1)
94
+
95
+ # Get the top beam_size scores and their respective indices
96
+ top_scores, top_indices = scores.topk(beam_size, dim=1)
97
+
98
+ # Update beam scores
99
+ beam_scores = top_scores.log()
100
+
101
+ # Reshape input_ids
102
+ input_ids = input_ids.view(batch_size, beam_size, -1)
103
+
104
+ # Compute next inputs
105
+ next_token_ids = top_indices % self.vocab_size
106
+ beam_indices = top_indices // self.vocab_size
107
+ next_input_ids = torch.cat([input_ids.gather(1, beam_indices.unsqueeze(2).expand(-1, -1, input_ids.size(2))), next_token_ids.unsqueeze(2)], dim=2)
108
+
109
+ # Flatten input_ids
110
+ input_ids = next_input_ids.view(batch_size*beam_size, -1)
111
+
112
+ # Check which beams are done
113
+ done = (next_token_ids == self.tokenizer.eos_token_id).all(dim=1).tolist()
114
+
115
+ if all(done):
116
+ break
117
+
118
+ return input_ids.view(batch_size, beam_size, -1)
119
+
120
+
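The greedy branch above freezes each sequence once it emits EOS while the rest of the batch keeps decoding. A self-contained sketch of that bookkeeping, with a toy language model standing in for the GPT-2 head (all names here are illustrative):

```python
import torch
import torch.nn as nn

def greedy_decode(lm, input_ids, eos_id, pad_id, max_new=20):
    """lm maps (batch, seq) ids to (batch, seq, vocab) logits."""
    finished = torch.zeros(input_ids.size(0), dtype=torch.bool)
    for _ in range(max_new):
        logits = lm(input_ids)[:, -1, :]
        nxt = logits.argmax(-1)
        nxt[finished] = pad_id            # finished rows stay padded
        finished |= nxt == eos_id
        input_ids = torch.cat([input_ids, nxt.unsqueeze(1)], dim=1)
        if finished.all():
            break
    return input_ids

class ToyLM(nn.Module):                   # stand-in for GPT2LMHeadModel
    def __init__(self, vocab=12, dim=8):
        super().__init__()
        self.emb, self.head = nn.Embedding(vocab, dim), nn.Linear(dim, vocab)
    def forward(self, ids):
        return self.head(self.emb(ids))

print(greedy_decode(ToyLM(), torch.zeros(2, 1, dtype=torch.long), eos_id=1, pad_id=0))
```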
modules/caption_model.py ADDED
@@ -0,0 +1,401 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ import modules.utils as utils
10
+
11
+
12
+ class CaptionModel(nn.Module):
13
+ def __init__(self):
14
+ super(CaptionModel, self).__init__()
15
+
16
+ # implements beam search
17
+ # calls beam_step and returns the final set of beams
18
+ # augments log-probabilities with diversity terms when number of groups > 1
19
+
20
+ def forward(self, *args, **kwargs):
21
+ mode = kwargs.get('mode', 'forward')
22
+ if 'mode' in kwargs:
23
+ del kwargs['mode']
24
+ return getattr(self, '_' + mode)(*args, **kwargs)
25
+
26
+ def beam_search(self, init_state, init_logprobs, *args, **kwargs):
27
+
28
+ # function computes the similarity score to be augmented
29
+ def add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda, bdash):
30
+ local_time = t - divm
31
+ unaug_logprobs = logprobs.clone()
32
+ batch_size = beam_seq_table[0].shape[0]
33
+
34
+ if divm > 0:
35
+ change = logprobs.new_zeros(batch_size, logprobs.shape[-1])
36
+ for prev_choice in range(divm):
37
+ prev_decisions = beam_seq_table[prev_choice][:, :, local_time] # Nxb
38
+ for prev_labels in range(bdash):
39
+ change.scatter_add_(1, prev_decisions[:, prev_labels].unsqueeze(-1),
40
+ change.new_ones(batch_size, 1))
41
+
42
+ if local_time == 0:
43
+ logprobs = logprobs - change * diversity_lambda
44
+ else:
45
+ logprobs = logprobs - self.repeat_tensor(bdash, change) * diversity_lambda
46
+
47
+ return logprobs, unaug_logprobs
48
+
49
+ # does one step of classical beam search
50
+
51
+ def beam_step(logprobs, unaug_logprobs, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
52
+ # INPUTS:
53
+ # logprobs: probabilities augmented after diversity N*bxV
54
+ # beam_size: obvious
55
+ # t : time instant
56
+ # beam_seq : tensor containing the beams
57
+ # beam_seq_logprobs: tensor containing the beam logprobs
58
+ # beam_logprobs_sum: tensor containing joint logprobs
59
+ # OUTPUTS:
60
+ # beam_seq : tensor containing the word indices of the decoded captions Nxbxl
61
+ # beam_seq_logprobs : log-probability of each decision made, NxbxlxV
62
+ # beam_logprobs_sum : joint log-probability of each beam Nxb
63
+
64
+ batch_size = beam_logprobs_sum.shape[0]
65
+ vocab_size = logprobs.shape[-1]
66
+ logprobs = logprobs.reshape(batch_size, -1, vocab_size) # NxbxV
67
+ if t == 0:
68
+ assert logprobs.shape[1] == 1
69
+ beam_logprobs_sum = beam_logprobs_sum[:, :1]
70
+ candidate_logprobs = beam_logprobs_sum.unsqueeze(-1) + logprobs # beam_logprobs_sum Nxb logprobs is NxbxV
71
+ ys, ix = torch.sort(candidate_logprobs.reshape(candidate_logprobs.shape[0], -1), -1, True)
72
+ ys, ix = ys[:, :beam_size], ix[:, :beam_size]
73
+ beam_ix = ix // vocab_size # Nxb which beam
74
+ selected_ix = ix % vocab_size # Nxb, which word
75
+ state_ix = (beam_ix + torch.arange(batch_size).type_as(beam_ix).unsqueeze(-1) * logprobs.shape[1]).reshape(
76
+ -1) # N*b which in Nxb beams
77
+
78
+ if t > 0:
79
+ # gather according to beam_ix
80
+ assert (beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq)) ==
81
+ beam_seq.reshape(-1, beam_seq.shape[-1])[state_ix].view_as(beam_seq)).all()
82
+ beam_seq = beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq))
83
+
84
+ beam_seq_logprobs = beam_seq_logprobs.gather(1, beam_ix.unsqueeze(-1).unsqueeze(-1).expand_as(
85
+ beam_seq_logprobs))
86
+
87
+ beam_seq = torch.cat([beam_seq, selected_ix.unsqueeze(-1)], -1) # beam_seq Nxbxl
88
+ beam_logprobs_sum = beam_logprobs_sum.gather(1, beam_ix) + \
89
+ logprobs.reshape(batch_size, -1).gather(1, ix)
90
+ assert (beam_logprobs_sum == ys).all()
91
+ _tmp_beam_logprobs = unaug_logprobs[state_ix].reshape(batch_size, -1, vocab_size)
92
+ beam_logprobs = unaug_logprobs.reshape(batch_size, -1, vocab_size).gather(1,
93
+ beam_ix.unsqueeze(-1).expand(-1,
94
+ -1,
95
+ vocab_size)) # NxbxV
96
+ assert (_tmp_beam_logprobs == beam_logprobs).all()
97
+ beam_seq_logprobs = torch.cat([
98
+ beam_seq_logprobs,
99
+ beam_logprobs.reshape(batch_size, -1, 1, vocab_size)], 2)
100
+
101
+ new_state = [None for _ in state]
102
+ for _ix in range(len(new_state)):
103
+ # copy over state in previous beam q to new beam at vix
104
+ new_state[_ix] = state[_ix][:, state_ix]
105
+ state = new_state
106
+ return beam_seq, beam_seq_logprobs, beam_logprobs_sum, state
107
+
108
+ # Start diverse_beam_search
109
+ opt = kwargs['opt']
110
+ temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs
111
+ beam_size = opt.get('beam_size', 10)
112
+ group_size = opt.get('group_size', 1)
113
+ diversity_lambda = opt.get('diversity_lambda', 0.5)
114
+ decoding_constraint = opt.get('decoding_constraint', 0)
115
+ suppress_UNK = opt.get('suppress_UNK', 0)
116
+ length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
117
+ bdash = beam_size // group_size # beam per group
118
+
119
+ batch_size = init_logprobs.shape[0]
120
+ device = init_logprobs.device
121
+ # INITIALIZATIONS
122
+ beam_seq_table = [torch.LongTensor(batch_size, bdash, 0).to(device) for _ in range(group_size)]
123
+ beam_seq_logprobs_table = [torch.FloatTensor(batch_size, bdash, 0, self.vocab_size + 1).to(device) for _ in
124
+ range(group_size)]
125
+ beam_logprobs_sum_table = [torch.zeros(batch_size, bdash).to(device) for _ in range(group_size)]
126
+
127
+ # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1)
128
+ done_beams_table = [[[] for __ in range(group_size)] for _ in range(batch_size)]
129
+ state_table = [[_.clone() for _ in init_state] for _ in range(group_size)]
130
+ logprobs_table = [init_logprobs.clone() for _ in range(group_size)]
131
+ # END INIT
132
+
133
+ # Chunk elements in the args
134
+ args = list(args)
135
+ args = utils.split_tensors(group_size, args) # For each arg, turn (Bbg)x... to (Bb)x(g)x...
136
+ if self.__class__.__name__ == 'AttEnsemble':
137
+ args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in
138
+ range(group_size)] # group_name, arg_name, model_name
139
+ else:
140
+ args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]
141
+
142
+ for t in range(self.max_seq_length + group_size - 1):
143
+ for divm in range(group_size):
144
+ if t >= divm and t <= self.max_seq_length + divm - 1:
145
+ # add diversity
146
+ logprobs = logprobs_table[divm]
147
+ # suppress previous word
148
+ if decoding_constraint and t - divm > 0:
149
+ logprobs.scatter_(1, beam_seq_table[divm][:, :, t - divm - 1].reshape(-1, 1).to(device),
150
+ float('-inf'))
151
+ # suppress UNK tokens in the decoding
152
+ if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobs.size(1) - 1)] == 'UNK':
153
+ logprobs[:, logprobs.size(1) - 1] = logprobs[:, logprobs.size(1) - 1] - 1000
154
+ # diversity is added here
155
+ # the function directly modifies the logprobs values and hence, we need to return
156
+ # the unaugmented ones for sorting the candidates in the end. # for historical
157
+ # reasons :-)
158
+ logprobs, unaug_logprobs = add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda, bdash)
159
+
160
+ # infer new beams
161
+ beam_seq_table[divm], \
162
+ beam_seq_logprobs_table[divm], \
163
+ beam_logprobs_sum_table[divm], \
164
+ state_table[divm] = beam_step(logprobs,
165
+ unaug_logprobs,
166
+ bdash,
167
+ t - divm,
168
+ beam_seq_table[divm],
169
+ beam_seq_logprobs_table[divm],
170
+ beam_logprobs_sum_table[divm],
171
+ state_table[divm])
172
+
173
+ # if time's up... or if end token is reached then copy beams
174
+ for b in range(batch_size):
175
+ is_end = beam_seq_table[divm][b, :, t - divm] == self.eos_idx
176
+ assert beam_seq_table[divm].shape[-1] == t - divm + 1
177
+ if t == self.max_seq_length + divm - 1:
178
+ is_end.fill_(1)
179
+ for vix in range(bdash):
180
+ if is_end[vix]:
181
+ final_beam = {
182
+ 'seq': beam_seq_table[divm][b, vix].clone(),
183
+ 'logps': beam_seq_logprobs_table[divm][b, vix].clone(),
184
+ 'unaug_p': beam_seq_logprobs_table[divm][b, vix].sum().item(),
185
+ 'p': beam_logprobs_sum_table[divm][b, vix].item()
186
+ }
187
+ final_beam['p'] = length_penalty(t - divm + 1, final_beam['p'])
188
+ done_beams_table[b][divm].append(final_beam)
189
+ beam_logprobs_sum_table[divm][b, is_end] -= 1000
190
+
191
+ # move the current group one step forward in time
192
+
193
+ it = beam_seq_table[divm][:, :, t - divm].reshape(-1)
194
+ logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it.cuda(), *(
195
+ args[divm] + [state_table[divm]]))
196
+ logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)
197
+
198
+ # all beams are sorted by their log-probabilities
199
+ done_beams_table = [[sorted(done_beams_table[b][i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)]
200
+ for b in range(batch_size)]
201
+ done_beams = [sum(_, []) for _ in done_beams_table]
202
+ return done_beams
203
+
204
+ def old_beam_search(self, init_state, init_logprobs, *args, **kwargs):
205
+
206
+ # function computes the similarity score to be augmented
207
+ def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash):
208
+ local_time = t - divm
209
+ unaug_logprobsf = logprobsf.clone()
210
+ for prev_choice in range(divm):
211
+ prev_decisions = beam_seq_table[prev_choice][local_time]
212
+ for sub_beam in range(bdash):
213
+ for prev_labels in range(bdash):
214
+ logprobsf[sub_beam][prev_decisions[prev_labels]] = logprobsf[sub_beam][prev_decisions[
215
+ prev_labels]] - diversity_lambda
216
+ return unaug_logprobsf
217
+
218
+ # does one step of classical beam search
219
+
220
+ def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
221
+ # INPUTS:
222
+ # logprobsf: probabilities augmented after diversity
223
+ # beam_size: obvious
224
+ # t : time instant
225
+ # beam_seq : tensor containing the beams
226
+ # beam_seq_logprobs: tensor containing the beam logprobs
227
+ # beam_logprobs_sum: tensor containing joint logprobs
228
+ # OUTPUTS:
229
+ # beam_seq : tensor containing the word indices of the decoded captions
230
+ # beam_seq_logprobs : log-probability of each decision made, same size as beam_seq
231
+ # beam_logprobs_sum : joint log-probability of each beam
232
+
233
+ ys, ix = torch.sort(logprobsf, 1, True)
234
+ candidates = []
235
+ cols = min(beam_size, ys.size(1))
236
+ rows = beam_size
237
+ if t == 0:
238
+ rows = 1
239
+ for c in range(cols): # for each column (word, essentially)
240
+ for q in range(rows): # for each beam expansion
241
+ # compute logprob of expanding beam q with word in (sorted) position c
242
+ local_logprob = ys[q, c].item()
243
+ candidate_logprob = beam_logprobs_sum[q] + local_logprob
244
+ # local_unaug_logprob = unaug_logprobsf[q,ix[q,c]]
245
+ candidates.append({'c': ix[q, c], 'q': q, 'p': candidate_logprob, 'r': unaug_logprobsf[q]})
246
+ candidates = sorted(candidates, key=lambda x: -x['p'])
247
+
248
+ new_state = [_.clone() for _ in state]
249
+ # beam_seq_prev, beam_seq_logprobs_prev
250
+ if t >= 1:
251
+ # we''ll need these as reference when we fork beams around
252
+ beam_seq_prev = beam_seq[:t].clone()
253
+ beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone()
254
+ for vix in range(beam_size):
255
+ v = candidates[vix]
256
+ # fork beam index q into index vix
257
+ if t >= 1:
258
+ beam_seq[:t, vix] = beam_seq_prev[:, v['q']]
259
+ beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']]
260
+ # rearrange recurrent states
261
+ for state_ix in range(len(new_state)):
262
+ # copy over state in previous beam q to new beam at vix
263
+ new_state[state_ix][:, vix] = state[state_ix][:, v['q']] # dimension one is time step
264
+ # append new end terminal at the end of this beam
265
+ beam_seq[t, vix] = v['c'] # c'th word is the continuation
266
+ beam_seq_logprobs[t, vix] = v['r'] # the raw logprob here
267
+ beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam
268
+ state = new_state
269
+ return beam_seq, beam_seq_logprobs, beam_logprobs_sum, state, candidates
270
+
271
+ # Start diverse_beam_search
272
+ opt = kwargs['opt']
273
+ temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs
274
+ beam_size = opt.get('beam_size', 10)
275
+ group_size = opt.get('group_size', 1)
276
+ diversity_lambda = opt.get('diversity_lambda', 0.5)
277
+ decoding_constraint = opt.get('decoding_constraint', 0)
278
+ suppress_UNK = opt.get('suppress_UNK', 0)
279
+ length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
280
+ bdash = beam_size // group_size # beam per group
281
+
282
+ # INITIALIZATIONS
283
+ beam_seq_table = [torch.LongTensor(self.max_seq_length, bdash).zero_() for _ in range(group_size)]
284
+ beam_seq_logprobs_table = [torch.FloatTensor(self.max_seq_length, bdash, self.vocab_size + 1).zero_() for _ in
285
+ range(group_size)]
286
+ beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)]
287
+
288
+ # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1)
289
+ done_beams_table = [[] for _ in range(group_size)]
290
+ # state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)]
291
+ state_table = list(zip(*[_.chunk(group_size, 1) for _ in init_state]))
292
+ logprobs_table = list(init_logprobs.chunk(group_size, 0))
293
+ # END INIT
294
+
295
+ # Chunk elements in the args
296
+ args = list(args)
297
+ if self.__class__.__name__ == 'AttEnsemble':
298
+ args = [[_.chunk(group_size) if _ is not None else [None] * group_size for _ in args_] for args_ in
299
+ args] # arg_name, model_name, group_name
300
+ args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in
301
+ range(group_size)] # group_name, arg_name, model_name
302
+ else:
303
+ args = [_.chunk(group_size) if _ is not None else [None] * group_size for _ in args]
304
+ args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]
305
+
306
+ for t in range(self.max_seq_length + group_size - 1):
307
+ for divm in range(group_size):
308
+ if t >= divm and t <= self.max_seq_length + divm - 1:
309
+ # add diversity
310
+ logprobsf = logprobs_table[divm].float()
311
+ # suppress previous word
312
+ if decoding_constraint and t - divm > 0:
313
+ logprobsf.scatter_(1, beam_seq_table[divm][t - divm - 1].unsqueeze(1).cuda(), float('-inf'))
314
+ # suppress UNK tokens in the decoding
315
+ if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobsf.size(1) - 1)] == 'UNK':
316
+ logprobsf[:, logprobsf.size(1) - 1] = logprobsf[:, logprobsf.size(1) - 1] - 1000
317
+ # diversity is added here
318
+ # the function directly modifies the logprobsf values and hence, we need to return
319
+ # the unaugmented ones for sorting the candidates in the end. # for historical
320
+ # reasons :-)
321
+ unaug_logprobsf = add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash)
322
+
323
+ # infer new beams
324
+ beam_seq_table[divm], \
325
+ beam_seq_logprobs_table[divm], \
326
+ beam_logprobs_sum_table[divm], \
327
+ state_table[divm], \
328
+ candidates_divm = beam_step(logprobsf,
329
+ unaug_logprobsf,
330
+ bdash,
331
+ t - divm,
332
+ beam_seq_table[divm],
333
+ beam_seq_logprobs_table[divm],
334
+ beam_logprobs_sum_table[divm],
335
+ state_table[divm])
336
+
337
+ # if time's up... or if end token is reached then copy beams
338
+ for vix in range(bdash):
339
+ if beam_seq_table[divm][t - divm, vix] == self.eos_idx or t == self.max_seq_length + divm - 1:
340
+ final_beam = {
341
+ 'seq': beam_seq_table[divm][:, vix].clone(),
342
+ 'logps': beam_seq_logprobs_table[divm][:, vix].clone(),
343
+ 'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum().item(),
344
+ 'p': beam_logprobs_sum_table[divm][vix].item()
345
+ }
346
+ final_beam['p'] = length_penalty(t - divm + 1, final_beam['p'])
347
+ done_beams_table[divm].append(final_beam)
348
+ # don't continue beams from finished sequences
349
+ beam_logprobs_sum_table[divm][vix] = -1000
350
+
351
+ # move the current group one step forward in time
352
+
353
+ it = beam_seq_table[divm][t - divm]
354
+ logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it.cuda(), *(
355
+ args[divm] + [state_table[divm]]))
356
+ logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)
357
+
358
+ # all beams are sorted by their log-probabilities
359
+ done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)]
360
+ done_beams = sum(done_beams_table, [])
361
+ return done_beams
362
+
363
+ def sample_next_word(self, logprobs, sample_method, temperature):
364
+ if sample_method == 'greedy':
365
+ sampleLogprobs, it = torch.max(logprobs.data, 1)
366
+ it = it.view(-1).long()
367
+ elif sample_method == 'gumbel': # gumbel softmax
368
+ def sample_gumbel(shape, eps=1e-20):
369
+ U = torch.rand(shape).cuda()
370
+ return -torch.log(-torch.log(U + eps) + eps)
371
+
372
+ def gumbel_softmax_sample(logits, temperature):
373
+ y = logits + sample_gumbel(logits.size())
374
+ return F.log_softmax(y / temperature, dim=-1)
375
+
376
+ _logprobs = gumbel_softmax_sample(logprobs, temperature)
377
+ _, it = torch.max(_logprobs.data, 1)
378
+ sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions
379
+ else:
380
+ logprobs = logprobs / temperature
381
+ if sample_method.startswith('top'): # topk sampling
382
+ top_num = float(sample_method[3:])
383
+ if 0 < top_num < 1:
384
+ # nucleus sampling from # The Curious Case of Neural Text Degeneration
385
+ probs = F.softmax(logprobs, dim=1)
386
+ sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)
387
+ _cumsum = sorted_probs.cumsum(1)
388
+ mask = _cumsum < top_num
389
+ mask = torch.cat([torch.ones_like(mask[:, :1]), mask[:, :-1]], 1)
390
+ sorted_probs = sorted_probs * mask.float()
391
+ sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
392
+ logprobs.scatter_(1, sorted_indices, sorted_probs.log())
393
+ else:
394
+ the_k = int(top_num)
395
+ tmp = torch.empty_like(logprobs).fill_(float('-inf'))
396
+ topk, indices = torch.topk(logprobs, the_k, dim=1)
397
+ tmp = tmp.scatter(1, indices, topk)
398
+ logprobs = tmp
399
+ it = torch.distributions.Categorical(logits=logprobs.detach()).sample()
400
+ sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions
401
+ return it, sampleLogprobs
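`sample_next_word` above implements nucleus (top-p) sampling inline when `sample_method` is, e.g., `top0.9`. The same filtering step extracted into a standalone function (the name and default `top_p` are illustrative):

```python
import torch
import torch.nn.functional as F

def nucleus_sample(logprobs, top_p=0.9):
    """Keep the smallest set of top tokens whose probability mass reaches top_p."""
    probs = F.softmax(logprobs, dim=1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=1)
    keep = sorted_probs.cumsum(1) < top_p
    keep = torch.cat([torch.ones_like(keep[:, :1]), keep[:, :-1]], 1)  # always keep top-1
    sorted_probs = sorted_probs * keep.float()
    sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
    filtered = torch.full_like(logprobs, float('-inf'))
    filtered.scatter_(1, sorted_idx, sorted_probs.log())  # dropped tokens become -inf
    it = torch.distributions.Categorical(logits=filtered).sample()
    return it, filtered.gather(1, it.unsqueeze(1))

it, lp = nucleus_sample(torch.randn(2, 10))
```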
modules/config.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c454e6bddb15af52c82734f1796391bf3a10a6c5533ea095de06f661ebb858bb
3
+ size 1744
modules/dataloader.py ADDED
@@ -0,0 +1,59 @@
1
+ import pdb
2
+
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+ from .dataset import IuxrayMultiImageDataset, MimiccxrSingleImageDataset
6
+ from medclip import MedCLIPProcessor
7
+ import numpy as np
8
+
9
+ class R2DataLoader(DataLoader):
10
+ def __init__(self, args, tokenizer, split, shuffle):
11
+ self.args = args
12
+ self.dataset_name = args.dataset
13
+ self.batch_size = args.bs
14
+ self.shuffle = shuffle
15
+ self.num_workers = args.num_workers
16
+ self.tokenizer = tokenizer
17
+ self.split = split
18
+ self.processor = MedCLIPProcessor()
19
+
20
+ if self.dataset_name == 'iu_xray':
21
+ self.dataset = IuxrayMultiImageDataset(self.args, self.tokenizer, self.split, self.processor)
22
+ else:
23
+ self.dataset = MimiccxrSingleImageDataset(self.args, self.tokenizer, self.split, self.processor)
24
+
25
+ self.init_kwargs = {
26
+ 'dataset': self.dataset,
27
+ 'batch_size': self.batch_size,
28
+ 'shuffle': self.shuffle,
29
+ 'collate_fn': self.collate_fn,
30
+ 'num_workers': self.num_workers
31
+ }
32
+ super().__init__(**self.init_kwargs)
33
+
34
+ @staticmethod
35
+ def collate_fn(data):
36
+ image_id_batch, image_batch, report_ids_batch, report_masks_batch, processor_ids_batch, processor_mask_batch, seq_lengths_batch, processor_lengths_batch = zip(*data)
37
+ image_batch = torch.stack(image_batch, 0)
38
+
39
+ max_seq_length = max(seq_lengths_batch)
40
+ target_batch = np.zeros((len(report_ids_batch), max_seq_length), dtype=int)
41
+ target_masks_batch = np.zeros((len(report_ids_batch), max_seq_length), dtype=int)
42
+
43
+ max_processor_length = max(processor_lengths_batch)
44
+ target_processor_batch = np.zeros((len(processor_ids_batch), max_processor_length), dtype=int)
45
+ target_processor_mask_batch = np.zeros((len(processor_mask_batch), max_processor_length), dtype=int)
46
+
47
+ for i, report_ids in enumerate(report_ids_batch):
48
+ target_batch[i, :len(report_ids)] = report_ids
49
+
50
+ for i, report_masks in enumerate(report_masks_batch):
51
+ target_masks_batch[i, :len(report_masks)] = report_masks
52
+
53
+ for i, report_ids in enumerate(processor_ids_batch):
54
+ target_processor_batch[i, :len(report_ids)] = report_ids
55
+
56
+ for i, report_masks in enumerate(processor_mask_batch):
57
+ target_processor_mask_batch[i, :len(report_masks)] = report_masks
58
+
59
+ return image_id_batch, image_batch, torch.LongTensor(target_batch), torch.FloatTensor(target_masks_batch), torch.LongTensor(target_processor_batch), torch.FloatTensor(target_processor_mask_batch)
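In miniature, the padding scheme used by `collate_fn` above: every ragged id list is right-padded with zeros out to the longest sequence in the batch before tensor conversion.

```python
import numpy as np
import torch

report_ids_batch = ([2, 5, 9, 3], [2, 7, 3])          # toy, variable-length ids
max_len = max(len(ids) for ids in report_ids_batch)
targets = np.zeros((len(report_ids_batch), max_len), dtype=int)
for i, ids in enumerate(report_ids_batch):
    targets[i, :len(ids)] = ids                        # pad id 0 fills the tail
print(torch.LongTensor(targets))
# tensor([[2, 5, 9, 3],
#         [2, 7, 3, 0]])
```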
modules/dataloaders.py ADDED
@@ -0,0 +1,62 @@
1
+ import torch
2
+ import numpy as np
3
+ from torchvision import transforms
4
+ from torch.utils.data import DataLoader
5
+ from .datasets import IuxrayMultiImageDataset, MimiccxrSingleImageDataset
6
+
7
+
8
+ class R2DataLoader(DataLoader):
9
+ def __init__(self, args, tokenizer, split, shuffle):
10
+ self.args = args
11
+ self.dataset_name = args.dataset_name
12
+ self.batch_size = args.batch_size
13
+ self.shuffle = shuffle
14
+ self.num_workers = args.num_workers
15
+ self.tokenizer = tokenizer
16
+ self.split = split
17
+
18
+ if split == 'train':
19
+ self.transform = transforms.Compose([
20
+ transforms.Resize(256),
21
+ transforms.RandomCrop(224),
22
+ transforms.RandomHorizontalFlip(),
23
+ transforms.ToTensor(),
24
+ transforms.Normalize((0.485, 0.456, 0.406),
25
+ (0.229, 0.224, 0.225))])
26
+ else:
27
+ self.transform = transforms.Compose([
28
+ transforms.Resize((224, 224)),
29
+ transforms.ToTensor(),
30
+ transforms.Normalize((0.485, 0.456, 0.406),
31
+ (0.229, 0.224, 0.225))])
32
+
33
+ if self.dataset_name == 'iu_xray':
34
+ self.dataset = IuxrayMultiImageDataset(self.args, self.tokenizer, self.split, transform=self.transform)
35
+ else:
36
+ self.dataset = MimiccxrSingleImageDataset(self.args, self.tokenizer, self.split, transform=self.transform)
37
+
38
+ self.init_kwargs = {
39
+ 'dataset': self.dataset,
40
+ 'batch_size': self.batch_size,
41
+ 'shuffle': self.shuffle,
42
+ 'collate_fn': self.collate_fn,
43
+ 'num_workers': self.num_workers
44
+ }
45
+ super().__init__(**self.init_kwargs)
46
+
47
+ @staticmethod
48
+ def collate_fn(data):
49
+ images_id, images, reports_ids, reports_masks, seq_lengths = zip(*data)
50
+ images = torch.stack(images, 0)
51
+ max_seq_length = max(seq_lengths)
52
+
53
+ targets = np.zeros((len(reports_ids), max_seq_length), dtype=int)
54
+ targets_masks = np.zeros((len(reports_ids), max_seq_length), dtype=int)
55
+
56
+ for i, report_ids in enumerate(reports_ids):
57
+ targets[i, :len(report_ids)] = report_ids
58
+
59
+ for i, report_masks in enumerate(reports_masks):
60
+ targets_masks[i, :len(report_masks)] = report_masks
61
+
62
+ return images_id, images, torch.LongTensor(targets), torch.FloatTensor(targets_masks)
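The eval-time pipeline above resizes to 224x224 and applies the standard ImageNet normalization; here it is applied to a dummy image (PIL and torchvision are already dependencies of this module):

```python
from PIL import Image
from torchvision import transforms

eval_tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
img = Image.new('RGB', (512, 512))      # dummy stand-in for a chest X-ray
print(eval_tf(img).shape)               # torch.Size([3, 224, 224])
```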
modules/dataset.py ADDED
@@ -0,0 +1,68 @@
1
+ import os
2
+ from PIL import Image
3
+ import json
4
+ from torch.utils.data import Dataset
5
+ import numpy as np
6
+ import torch
7
+
8
+
9
+ class BaseDataset(Dataset):
10
+ def __init__(self, args, tokenizer, split, processor):
11
+ self.image_dir = args.image_dir
12
+ self.ann_path = args.json_path
13
+ self.max_seq_length = args.max_seq_length
14
+ self.split = split
15
+ self.tokenizer = tokenizer
16
+ self.ann = json.loads(open(self.ann_path, 'r').read())
17
+ self.examples = self.ann[self.split]
18
+ self.processor = processor
19
+
20
+ def preprocess_text(self, text):
21
+ ids = self.tokenizer(text)[:self.max_seq_length]
22
+ mask = [1] * len(ids)
23
+ text_inputs = self.processor(text=text, return_tensors="pt", truncation=True, padding=False, max_length=self.max_seq_length)
24
+ processor_ids = text_inputs['input_ids'].squeeze(0).tolist()
25
+ processor_mask = text_inputs['attention_mask'].squeeze(0).tolist()
26
+ return ids, mask, processor_ids, processor_mask
27
+
28
+ def __len__(self):
29
+ return len(self.examples)
30
+
31
+
32
+ class IuxrayMultiImageDataset(BaseDataset):
33
+ def __getitem__(self, idx):
34
+ example = self.examples[idx]
35
+ report = example['report']
36
+ report_ids, report_masks, processor_ids, processor_mask = self.preprocess_text(report)
37
+
38
+ image_id = example['id']
39
+ image_path = example['image_path']
40
+ image_1 = Image.open(os.path.join(self.image_dir, image_path[0])).convert('RGB')
41
+ image_2 = Image.open(os.path.join(self.image_dir, image_path[1])).convert('RGB')
42
+ # MedCLIP processing
43
+ image_inputs_1 = self.processor(images=image_1, return_tensors="pt")
44
+ image_inputs_2 = self.processor(images=image_2, return_tensors="pt")
45
+ image = torch.stack((image_inputs_1.pixel_values[0], image_inputs_2.pixel_values[0]), 0)
46
+
47
+ seq_length = len(report_ids)
48
+ processor_length = len(processor_ids)
49
+ sample = (image_id, image, report_ids, report_masks, processor_ids, processor_mask, seq_length, processor_length)
50
+ return sample
51
+
52
+
53
+ class MimiccxrSingleImageDataset(BaseDataset):
54
+ def __getitem__(self, idx):
55
+ example = self.examples[idx]
56
+ report = example['report']
57
+ report_ids, report_masks, processor_ids, processor_mask = self.preprocess_text(report)
58
+
59
+ image_id = example['id']
60
+ image_path = example['image_path']
61
+ image = Image.open(os.path.join(self.image_dir, image_path[0])).convert('RGB')
62
+ image_inputs = self.processor(images=image, return_tensors="pt")
63
+ image = image_inputs.pixel_values[0]
64
+
65
+ seq_length = len(report_ids)
66
+ processor_length = len(processor_ids)
67
+ sample = (image_id, image, report_ids, report_masks, processor_ids, processor_mask, seq_length, processor_length)
68
+ return sample
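IU X-ray supplies two views (frontal and lateral) per study, so `IuxrayMultiImageDataset` stacks the two MedCLIP-processed images along a new leading dimension; in miniature with random tensors:

```python
import torch

view_1 = torch.randn(3, 224, 224)   # stand-ins for processor pixel_values
view_2 = torch.randn(3, 224, 224)
image = torch.stack((view_1, view_2), 0)
print(image.shape)                  # torch.Size([2, 3, 224, 224])
```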
modules/datasets.py ADDED
@@ -0,0 +1,57 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ from PIL import Image
5
+ from torch.utils.data import Dataset
6
+
7
+
8
+ class BaseDataset(Dataset):
9
+ def __init__(self, args, tokenizer, split, transform=None):
10
+ self.image_dir = args.image_dir
11
+ self.ann_path = args.ann_path
12
+ self.max_seq_length = args.max_seq_length
13
+ self.split = split
14
+ self.tokenizer = tokenizer
15
+ self.transform = transform
16
+ self.ann = json.loads(open(self.ann_path, 'r').read())
17
+
18
+ self.examples = self.ann[self.split]
19
+ for i in range(len(self.examples)):
20
+ self.examples[i]['ids'] = tokenizer(self.examples[i]['report'])[:self.max_seq_length]
21
+ self.examples[i]['mask'] = [1] * len(self.examples[i]['ids'])
22
+
23
+ def __len__(self):
24
+ return len(self.examples)
25
+
26
+
27
+ class IuxrayMultiImageDataset(BaseDataset):
28
+ def __getitem__(self, idx):
29
+ example = self.examples[idx]
30
+ image_id = example['id']
31
+ image_path = example['image_path']
32
+ image_1 = Image.open(os.path.join(self.image_dir, image_path[0])).convert('RGB')
33
+ image_2 = Image.open(os.path.join(self.image_dir, image_path[1])).convert('RGB')
34
+ if self.transform is not None:
35
+ image_1 = self.transform(image_1)
36
+ image_2 = self.transform(image_2)
37
+ image = torch.stack((image_1, image_2), 0)
38
+ report_ids = example['ids']
39
+ report_masks = example['mask']
40
+ seq_length = len(report_ids)
41
+ sample = (image_id, image, report_ids, report_masks, seq_length)
42
+ return sample
43
+
44
+
45
+ class MimiccxrSingleImageDataset(BaseDataset):
46
+ def __getitem__(self, idx):
47
+ example = self.examples[idx]
48
+ image_id = example['id']
49
+ image_path = example['image_path']
50
+ image = Image.open(os.path.join(self.image_dir, image_path[0])).convert('RGB')
51
+ if self.transform is not None:
52
+ image = self.transform(image)
53
+ report_ids = example['ids']
54
+ report_masks = example['mask']
55
+ seq_length = len(report_ids)
56
+ sample = (image_id, image, report_ids, report_masks, seq_length)
57
+ return sample
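Unlike `dataset.py`, this variant tokenizes every report once in `__init__`: ids are truncated to `max_seq_length` and the mask marks each surviving token as real. The truncation and mask construction in isolation (the token ids are made up):

```python
max_seq_length = 6
ids = [2, 14, 93, 7, 55, 21, 8, 3][:max_seq_length]  # truncate long reports
mask = [1] * len(ids)                                # 1 = real token
print(ids, mask)  # [2, 14, 93, 7, 55, 21] [1, 1, 1, 1, 1, 1]
```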
modules/decoder.py ADDED
@@ -0,0 +1,50 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import pickle
5
+ from typing import Tuple
6
+ from transformers import GPT2LMHeadModel
7
+ from .att_models import AttModel
8
+ import pdb
9
+
10
+ class MLP(nn.Module):
11
+
12
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
13
+ return self.model(x)
14
+
15
+ def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
16
+ super(MLP, self).__init__()
17
+ layers = []
18
+ for i in range(len(sizes) - 1):
19
+ layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
20
+ if i < len(sizes) - 2:
21
+ layers.append(act())
22
+ self.model = nn.Sequential(*layers)
23
+
24
+ class DeCap(AttModel):
25
+
26
+ def __init__(self, args, tokenizer):
27
+ super(DeCap, self).__init__(args, tokenizer)
28
+
29
+ # decoder: 4 layers transformer with 4 attention heads
30
+ # the decoder is not pretrained
31
+ with open('./decoder_config/decoder_config.pkl', 'rb') as f:
32
+ config = pickle.load(f)
33
+ # Change the parameters you need
34
+ config.vocab_size = tokenizer.get_vocab_size()
35
+ config.bos_token_id = tokenizer.bos_token_id
36
+ config.eos_token_id = tokenizer.eos_token_id
37
+ self.decoder = GPT2LMHeadModel(config)
38
+ self.embedding_size = self.decoder.transformer.wte.weight.shape[1]
39
+ self.prefix_size = 512
40
+ self.clip_project = MLP((self.prefix_size, self.embedding_size))
41
+
42
+ def _forward(self, clip_features, gpt_tokens):
43
+
44
+ embedding_text = self.decoder.transformer.wte(gpt_tokens)
45
+ embedding_clip = self.clip_project(clip_features)
46
+ embedding_clip = embedding_clip.reshape(-1, 1, self.embedding_size)
47
+ embedding_cat = torch.cat([embedding_clip, embedding_text], dim=1)
48
+ out = self.decoder(inputs_embeds=embedding_cat)
49
+ return out
50
+
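`DeCap._forward` prepends a single projected CLIP feature as a prefix token in front of the text embeddings. A shape walk-through with toy sizes (a plain `nn.Linear` stands in for the `MLP`; the 768-wide embedding is only an example):

```python
import torch
import torch.nn as nn

prefix_size, embed_size, seq_len, batch = 512, 768, 10, 2
clip_project = nn.Linear(prefix_size, embed_size)      # stand-in for the MLP
clip_features = torch.randn(batch, prefix_size)
embedding_text = torch.randn(batch, seq_len, embed_size)

embedding_clip = clip_project(clip_features).reshape(-1, 1, embed_size)
embedding_cat = torch.cat([embedding_clip, embedding_text], dim=1)
print(embedding_cat.shape)  # torch.Size([2, 11, 768]) -> one prefix + text tokens
```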
modules/encoder_decoder.py ADDED
@@ -0,0 +1,391 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import copy
6
+ import math
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from .att_model import pack_wrapper, AttModel
14
+
15
+
16
+ def clones(module, N):
17
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
18
+
19
+
20
+ def attention(query, key, value, mask=None, dropout=None):
21
+ d_k = query.size(-1)
22
+ scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
23
+ if mask is not None:
24
+ scores = scores.masked_fill(mask == 0, -1e9)
25
+ p_attn = F.softmax(scores, dim=-1)
26
+ if dropout is not None:
27
+ p_attn = dropout(p_attn)
28
+ return torch.matmul(p_attn, value), p_attn
29
+
30
+
31
+ def subsequent_mask(size):
32
+ attn_shape = (1, size, size)
33
+ subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
34
+ return torch.from_numpy(subsequent_mask) == 0
35
+
36
+
37
+ class Transformer(nn.Module):
38
+ def __init__(self, encoder, decoder, src_embed, tgt_embed, rm):
39
+ super(Transformer, self).__init__()
40
+ self.encoder = encoder
41
+ self.decoder = decoder
42
+ self.src_embed = src_embed
43
+ self.tgt_embed = tgt_embed
44
+ self.rm = rm
45
+
46
+ def forward(self, src, tgt, src_mask, tgt_mask):
47
+ return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)
48
+
49
+ def encode(self, src, src_mask):
50
+ return self.encoder(self.src_embed(src), src_mask)
51
+
52
+ def decode(self, hidden_states, src_mask, tgt, tgt_mask):
53
+ memory = self.rm.init_memory(hidden_states.size(0)).to(hidden_states)
54
+ memory = self.rm(self.tgt_embed(tgt), memory)
55
+ return self.decoder(self.tgt_embed(tgt), hidden_states, src_mask, tgt_mask, memory)
56
+
57
+
58
+ class Encoder(nn.Module):
59
+ def __init__(self, layer, N):
60
+ super(Encoder, self).__init__()
61
+ self.layers = clones(layer, N)
62
+ self.norm = LayerNorm(layer.d_model)
63
+
64
+ def forward(self, x, mask):
65
+ for layer in self.layers:
66
+ x = layer(x, mask)
67
+ return self.norm(x)
68
+
69
+
70
+ class EncoderLayer(nn.Module):
71
+ def __init__(self, d_model, self_attn, feed_forward, dropout):
72
+ super(EncoderLayer, self).__init__()
73
+ self.self_attn = self_attn
74
+ self.feed_forward = feed_forward
75
+ self.sublayer = clones(SublayerConnection(d_model, dropout), 2)
76
+ self.d_model = d_model
77
+
78
+ def forward(self, x, mask):
79
+ x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
80
+ return self.sublayer[1](x, self.feed_forward)
81
+
82
+
83
+ class SublayerConnection(nn.Module):
84
+ def __init__(self, d_model, dropout):
85
+ super(SublayerConnection, self).__init__()
86
+ self.norm = LayerNorm(d_model)
87
+ self.dropout = nn.Dropout(dropout)
88
+
89
+ def forward(self, x, sublayer):
90
+ return x + self.dropout(sublayer(self.norm(x)))
91
+
92
+
93
+ class LayerNorm(nn.Module):
94
+ def __init__(self, features, eps=1e-6):
95
+ super(LayerNorm, self).__init__()
96
+ self.gamma = nn.Parameter(torch.ones(features))
97
+ self.beta = nn.Parameter(torch.zeros(features))
98
+ self.eps = eps
99
+
100
+ def forward(self, x):
101
+ mean = x.mean(-1, keepdim=True)
102
+ std = x.std(-1, keepdim=True)
103
+ return self.gamma * (x - mean) / (std + self.eps) + self.beta
104
+
105
+
106
+ class Decoder(nn.Module):
107
+ def __init__(self, layer, N):
108
+ super(Decoder, self).__init__()
109
+ self.layers = clones(layer, N)
110
+ self.norm = LayerNorm(layer.d_model)
111
+
112
+ def forward(self, x, hidden_states, src_mask, tgt_mask, memory):
113
+ for layer in self.layers:
114
+ x = layer(x, hidden_states, src_mask, tgt_mask, memory)
115
+ return self.norm(x)
116
+
117
+
118
+ class DecoderLayer(nn.Module):
119
+ def __init__(self, d_model, self_attn, src_attn, feed_forward, dropout, rm_num_slots, rm_d_model):
120
+ super(DecoderLayer, self).__init__()
121
+ self.d_model = d_model
122
+ self.self_attn = self_attn
123
+ self.src_attn = src_attn
124
+ self.feed_forward = feed_forward
125
+ self.sublayer = clones(ConditionalSublayerConnection(d_model, dropout, rm_num_slots, rm_d_model), 3)
126
+
127
+ def forward(self, x, hidden_states, src_mask, tgt_mask, memory):
128
+ m = hidden_states
129
+ x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask), memory)
130
+ x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask), memory)
131
+ return self.sublayer[2](x, self.feed_forward, memory)
132
+
133
+
134
+ class ConditionalSublayerConnection(nn.Module):
135
+ def __init__(self, d_model, dropout, rm_num_slots, rm_d_model):
136
+ super(ConditionalSublayerConnection, self).__init__()
137
+ self.norm = ConditionalLayerNorm(d_model, rm_num_slots, rm_d_model)
138
+ self.dropout = nn.Dropout(dropout)
139
+
140
+ def forward(self, x, sublayer, memory):
141
+ return x + self.dropout(sublayer(self.norm(x, memory)))
142
+
143
+
144
+ class ConditionalLayerNorm(nn.Module):
145
+ def __init__(self, d_model, rm_num_slots, rm_d_model, eps=1e-6):
146
+ super(ConditionalLayerNorm, self).__init__()
147
+ self.gamma = nn.Parameter(torch.ones(d_model))
148
+ self.beta = nn.Parameter(torch.zeros(d_model))
149
+ self.rm_d_model = rm_d_model
150
+ self.rm_num_slots = rm_num_slots
151
+ self.eps = eps
152
+
153
+ self.mlp_gamma = nn.Sequential(nn.Linear(rm_num_slots * rm_d_model, d_model),
154
+ nn.ReLU(inplace=True),
155
+ nn.Linear(d_model, d_model)) # input must match the d_model output of the first layer, as in mlp_beta
156
+
157
+ self.mlp_beta = nn.Sequential(nn.Linear(rm_num_slots * rm_d_model, d_model),
158
+ nn.ReLU(inplace=True),
159
+ nn.Linear(d_model, d_model))
160
+
161
+ for m in self.modules():
162
+ if isinstance(m, nn.Linear):
163
+ nn.init.xavier_uniform_(m.weight)
164
+ nn.init.constant_(m.bias, 0.1)
165
+
166
+ def forward(self, x, memory):
167
+ mean = x.mean(-1, keepdim=True)
168
+ std = x.std(-1, keepdim=True)
169
+ delta_gamma = self.mlp_gamma(memory)
170
+ delta_beta = self.mlp_beta(memory)
171
+ gamma_hat = self.gamma.clone()
172
+ beta_hat = self.beta.clone()
173
+ gamma_hat = torch.stack([gamma_hat] * x.size(0), dim=0)
174
+ gamma_hat = torch.stack([gamma_hat] * x.size(1), dim=1)
175
+ beta_hat = torch.stack([beta_hat] * x.size(0), dim=0)
176
+ beta_hat = torch.stack([beta_hat] * x.size(1), dim=1)
177
+ gamma_hat += delta_gamma
178
+ beta_hat += delta_beta
179
+ return gamma_hat * (x - mean) / (std + self.eps) + beta_hat
180
+
181
+
182
+ class MultiHeadedAttention(nn.Module):
183
+ def __init__(self, h, d_model, dropout=0.1):
184
+ super(MultiHeadedAttention, self).__init__()
185
+ assert d_model % h == 0
186
+ self.d_k = d_model // h
187
+ self.h = h
188
+ self.linears = clones(nn.Linear(d_model, d_model), 4)
189
+ self.attn = None
190
+ self.dropout = nn.Dropout(p=dropout)
191
+
192
+ def forward(self, query, key, value, mask=None):
193
+ if mask is not None:
194
+ mask = mask.unsqueeze(1)
195
+ nbatches = query.size(0)
196
+ query, key, value = \
197
+ [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
198
+ for l, x in zip(self.linears, (query, key, value))]
199
+
200
+ x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
201
+
202
+ x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
203
+ return self.linears[-1](x)
204
+
205
+
206
+ class PositionwiseFeedForward(nn.Module):
207
+ def __init__(self, d_model, d_ff, dropout=0.1):
208
+ super(PositionwiseFeedForward, self).__init__()
209
+ self.w_1 = nn.Linear(d_model, d_ff)
210
+ self.w_2 = nn.Linear(d_ff, d_model)
211
+ self.dropout = nn.Dropout(dropout)
212
+
213
+ def forward(self, x):
214
+ return self.w_2(self.dropout(F.relu(self.w_1(x))))
215
+
216
+
217
+ class Embeddings(nn.Module):
218
+ def __init__(self, d_model, vocab):
219
+ super(Embeddings, self).__init__()
220
+ self.lut = nn.Embedding(vocab, d_model)
221
+ self.d_model = d_model
222
+
223
+ def forward(self, x):
224
+ return self.lut(x) * math.sqrt(self.d_model)
225
+
226
+
227
+ class PositionalEncoding(nn.Module):
228
+ def __init__(self, d_model, dropout, max_len=5000):
229
+ super(PositionalEncoding, self).__init__()
230
+ self.dropout = nn.Dropout(p=dropout)
231
+
232
+ pe = torch.zeros(max_len, d_model)
233
+ position = torch.arange(0, max_len).unsqueeze(1).float()
234
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() *
235
+ -(math.log(10000.0) / d_model))
236
+ pe[:, 0::2] = torch.sin(position * div_term)
237
+ pe[:, 1::2] = torch.cos(position * div_term)
238
+ pe = pe.unsqueeze(0)
239
+ self.register_buffer('pe', pe)
240
+
241
+ def forward(self, x):
242
+ x = x + self.pe[:, :x.size(1)]
243
+ return self.dropout(x)
244
+
245
+
246
+ class RelationalMemory(nn.Module):
247
+
248
+ def __init__(self, num_slots, d_model, num_heads=1):
249
+ super(RelationalMemory, self).__init__()
250
+ self.num_slots = num_slots
251
+ self.num_heads = num_heads
252
+ self.d_model = d_model
253
+
254
+ self.attn = MultiHeadedAttention(num_heads, d_model)
255
+ self.mlp = nn.Sequential(nn.Linear(self.d_model, self.d_model),
256
+ nn.ReLU(),
257
+ nn.Linear(self.d_model, self.d_model),
258
+ nn.ReLU())
259
+
260
+ self.W = nn.Linear(self.d_model, self.d_model * 2)
261
+ self.U = nn.Linear(self.d_model, self.d_model * 2)
262
+
263
+ def init_memory(self, batch_size):
264
+ memory = torch.stack([torch.eye(self.num_slots)] * batch_size)
265
+ if self.d_model > self.num_slots:
266
+ diff = self.d_model - self.num_slots
267
+ pad = torch.zeros((batch_size, self.num_slots, diff))
268
+ memory = torch.cat([memory, pad], -1)
269
+ elif self.d_model < self.num_slots:
270
+ memory = memory[:, :, :self.d_model]
271
+
272
+ return memory
273
+
274
+ def forward_step(self, input, memory):
275
+ # print('inputinputinputinputinput',input.size())
276
+ # print('memorymemorymemorymemorymemorymemory',memory.size())
277
+
278
+ memory = memory.reshape(-1, self.num_slots, self.d_model)
279
+ # if input.shape[0]!=memory.shape[0]:
280
+ # input=input.repeat(round(memory.shape[0]/input.shape[0]),1)
281
+ q = memory
282
+ k = torch.cat([memory, input.unsqueeze(1)], 1)
283
+ v = torch.cat([memory, input.unsqueeze(1)], 1)
284
+ next_memory = memory + self.attn(q, k, v)
285
+ next_memory = next_memory + self.mlp(next_memory)
286
+
287
+ gates = self.W(input.unsqueeze(1)) + self.U(torch.tanh(memory))
288
+ gates = torch.split(gates, split_size_or_sections=self.d_model, dim=2)
289
+ input_gate, forget_gate = gates
290
+ input_gate = torch.sigmoid(input_gate)
291
+ forget_gate = torch.sigmoid(forget_gate)
292
+
293
+ next_memory = input_gate * torch.tanh(next_memory) + forget_gate * memory
294
+ next_memory = next_memory.reshape(-1, self.num_slots * self.d_model)
295
+
296
+ return next_memory
297
+
298
+ def forward(self, inputs, memory):
299
+ outputs = []
300
+ for i in range(inputs.shape[1]):
301
+ memory = self.forward_step(inputs[:, i], memory)
302
+ outputs.append(memory)
303
+ outputs = torch.stack(outputs, dim=1)
304
+
305
+ return outputs
306
+
307
+
308
+ class EncoderDecoder(AttModel):
309
+
310
+ def make_model(self, tgt_vocab):
311
+ c = copy.deepcopy
312
+ attn = MultiHeadedAttention(self.num_heads, self.d_model)
313
+ ff = PositionwiseFeedForward(self.d_model, self.d_ff, self.dropout)
314
+ position = PositionalEncoding(self.d_model, self.dropout)
315
+ rm = RelationalMemory(num_slots=self.rm_num_slots, d_model=self.rm_d_model, num_heads=self.rm_num_heads)
316
+ model = Transformer(
317
+ Encoder(EncoderLayer(self.d_model, c(attn), c(ff), self.dropout), self.num_layers),
318
+ Decoder(
319
+ DecoderLayer(self.d_model, c(attn), c(attn), c(ff), self.dropout, self.rm_num_slots, self.rm_d_model),
320
+ self.num_layers),
321
+ lambda x: x,
322
+ nn.Sequential(Embeddings(self.d_model, tgt_vocab), c(position)),
323
+ rm)
324
+ for p in model.parameters():
325
+ if p.dim() > 1:
326
+ nn.init.xavier_uniform_(p)
327
+ return model
328
+
329
+ def __init__(self, args, tokenizer):
330
+ super(EncoderDecoder, self).__init__(args, tokenizer)
331
+ self.args = args
332
+ self.num_layers = args.num_layers
333
+ self.d_model = args.d_model
334
+ self.d_ff = args.d_ff
335
+ self.num_heads = args.num_heads
336
+ self.dropout = args.dropout
337
+ self.rm_num_slots = args.rm_num_slots
338
+ self.rm_num_heads = args.rm_num_heads
339
+ self.rm_d_model = args.rm_d_model
340
+
341
+ tgt_vocab = self.vocab_size + 1
342
+
343
+ self.model = self.make_model(tgt_vocab)
344
+ self.logit = nn.Linear(args.d_model, tgt_vocab)
345
+
346
+ def init_hidden(self, bsz):
347
+ return []
348
+
349
+ def _prepare_feature(self, fc_feats, att_feats, att_masks):
350
+
351
+ att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks)
352
+ memory = self.model.encode(att_feats, att_masks)
353
+
354
+ return fc_feats[..., :1], att_feats[..., :1], memory, att_masks
355
+
356
+ def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None):
357
+ att_feats, att_masks = self.clip_att(att_feats, att_masks)
358
+ att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
359
+
360
+ if att_masks is None:
361
+ att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long)
362
+ att_masks = att_masks.unsqueeze(-2)
363
+
364
+ if seq is not None:
365
+ # crop the last one
366
+ seq = seq[:, :-1]
367
+ seq_mask = (seq.data > 0)
368
+ seq_mask[:, 0] += True
369
+
370
+ seq_mask = seq_mask.unsqueeze(-2)
371
+ seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask)
372
+ else:
373
+ seq_mask = None
374
+
375
+ return att_feats, seq, att_masks, seq_mask
376
+
377
+ def _forward(self, fc_feats, att_feats, seq, att_masks=None):
378
+
379
+ att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq)
380
+ out = self.model(att_feats, seq, att_masks, seq_mask)
381
+ outputs = F.log_softmax(self.logit(out), dim=-1)
382
+ return outputs
383
+
384
+ def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
385
+
386
+ if len(state) == 0:
387
+ ys = it.unsqueeze(1)
388
+ else:
389
+ ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
390
+ out = self.model.decode(memory, mask, ys, subsequent_mask(ys.size(1)).to(memory.device))
391
+ return out[:, -1], [ys.unsqueeze(0)]
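The causal mask built by `subsequent_mask` lets each position attend only to itself and earlier positions:

```python
from modules.encoder_decoder import subsequent_mask

print(subsequent_mask(4)[0].int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)
```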
modules/loss.py ADDED
@@ -0,0 +1,22 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class LanguageModelCriterion(nn.Module):
6
+ def __init__(self):
7
+ super(LanguageModelCriterion, self).__init__()
8
+
9
+ def forward(self, input, target, mask):
10
+ # truncate to the same size
11
+ target = target[:, :input.size(1)]
12
+ mask = mask[:, :input.size(1)]
13
+ output = -input.gather(2, target.long().unsqueeze(2)).squeeze(2) * mask
14
+ output = torch.sum(output) / torch.sum(mask)
15
+
16
+ return output
17
+
18
+
19
+ def compute_loss(output, reports_ids, reports_masks):
20
+ criterion = LanguageModelCriterion()
21
+ loss = criterion(output, reports_ids[:, 1:], reports_masks[:, 1:]).mean()
22
+ return loss
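The criterion gathers the log-probability of each gold token and averages over unmasked positions only; the core computation with toy tensors:

```python
import torch

logprobs = torch.log_softmax(torch.randn(2, 4, 10), dim=-1)  # (batch, T, vocab)
target = torch.randint(0, 10, (2, 4))
mask = torch.tensor([[1., 1., 1., 0.],                       # 0 = padding
                     [1., 1., 0., 0.]])
nll = -logprobs.gather(2, target.unsqueeze(2)).squeeze(2) * mask
print(nll.sum() / mask.sum())  # mean NLL over real tokens only
```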
modules/metrics.py ADDED
@@ -0,0 +1,33 @@
1
+ from pycocoevalcap.bleu.bleu import Bleu
2
+ from pycocoevalcap.meteor import Meteor
3
+ from pycocoevalcap.rouge import Rouge
4
+
5
+
6
+ def compute_scores(gts, res):
7
+ """
8
+ Performs the MS COCO evaluation using the Python 3 implementation (https://github.com/salaniz/pycocoevalcap)
9
+
10
+ :param gts: Dictionary mapping image ids to their gold captions
11
+ :param res: Dictionary mapping image ids to their generated captions
12
+ :return: Dictionary with the evaluation score (the mean over all instances) for each measure
13
+ """
14
+
15
+ # Set up scorers
16
+ scorers = [
17
+ (Bleu(4), ["BLEU_1", "BLEU_2", "BLEU_3", "BLEU_4"]),
18
+ (Meteor(), "METEOR"),
19
+ (Rouge(), "ROUGE_L")
20
+ ]
21
+ eval_res = {}
22
+ # Compute score for each metric
23
+ for scorer, method in scorers:
24
+ try:
25
+ score, scores = scorer.compute_score(gts, res, verbose=0)
26
+ except TypeError:
27
+ score, scores = scorer.compute_score(gts, res)
28
+ if type(method) == list:
29
+ for sc, m in zip(score, method):
30
+ eval_res[m] = sc
31
+ else:
32
+ eval_res[method] = score
33
+ return eval_res
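A toy call (METEOR additionally needs a working Java runtime for its bundled jar); `gts` and `res` map an example id to a list of strings, as the scorers expect:

```python
from modules.metrics import compute_scores

gts = {0: ['the heart size is normal'], 1: ['no acute cardiopulmonary findings']}
res = {0: ['the heart is normal in size'], 1: ['no acute cardiopulmonary findings']}
print(compute_scores(gts, res))
# {'BLEU_1': ..., 'BLEU_2': ..., 'BLEU_3': ..., 'BLEU_4': ..., 'METEOR': ..., 'ROUGE_L': ...}
```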
modules/optimizers.py ADDED
@@ -0,0 +1,18 @@
1
+ import torch
2
+
3
+
4
+ def build_optimizer(args, model):
5
+ ve_params = list(map(id, model.visual_extractor.parameters()))
6
+ ed_params = filter(lambda x: id(x) not in ve_params, model.parameters())
7
+ optimizer = getattr(torch.optim, args.optim)(
8
+ [{'params': model.visual_extractor.parameters(), 'lr': args.lr_ve},
9
+ {'params': ed_params, 'lr': args.lr_ed}],
10
+ weight_decay=args.weight_decay,
11
+ amsgrad=args.amsgrad
12
+ )
13
+ return optimizer
14
+
15
+
16
+ def build_lr_scheduler(args, optimizer):
17
+ lr_scheduler = getattr(torch.optim.lr_scheduler, args.lr_scheduler)(optimizer, args.step_size, args.gamma)
18
+ return lr_scheduler
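A minimal sketch of the two-learning-rate setup with hypothetical hyperparameter values; the toy model only needs a `visual_extractor` attribute for the id-based parameter split to work:

```python
import torch.nn as nn
from types import SimpleNamespace
from modules.optimizers import build_optimizer

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.visual_extractor = nn.Linear(4, 4)  # gets lr_ve
        self.decoder = nn.Linear(4, 4)           # everything else gets lr_ed

args = SimpleNamespace(optim='Adam', lr_ve=5e-5, lr_ed=1e-4,
                       weight_decay=5e-5, amsgrad=True)
opt = build_optimizer(args, Toy())
print([g['lr'] for g in opt.param_groups])       # [5e-05, 0.0001]
```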
modules/tester.py ADDED
@@ -0,0 +1,144 @@
1
+ import logging
2
+ import os
3
+ from abc import abstractmethod
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import pandas as pd
8
+ import spacy
9
+ import torch
10
+ from tqdm import tqdm
11
+
12
+ from modules.utils import generate_heatmap
13
+
14
+
15
+ class BaseTester(object):
16
+ def __init__(self, model, criterion, metric_ftns, args):
17
+ self.args = args
18
+
19
+ logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
20
+ datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)
21
+ self.logger = logging.getLogger(__name__)
22
+
23
+ # setup GPU device if available, move model into configured device
24
+ self.device, device_ids = self._prepare_device(args.n_gpu)
25
+ self.model = model.to(self.device)
26
+ if len(device_ids) > 1:
27
+ self.model = torch.nn.DataParallel(model, device_ids=device_ids)
28
+
29
+ self.criterion = criterion
30
+ self.metric_ftns = metric_ftns
31
+
32
+ self.epochs = self.args.epochs
33
+ self.save_dir = self.args.save_dir
34
+ if not os.path.exists(self.save_dir):
35
+ os.makedirs(self.save_dir)
36
+
37
+ self._load_checkpoint(args.load)
38
+
39
+ @abstractmethod
40
+ def test(self):
41
+ raise NotImplementedError
42
+
43
+ @abstractmethod
44
+ def plot(self):
45
+ raise NotImplementedError
46
+
47
+ def _prepare_device(self, n_gpu_use):
48
+ n_gpu = torch.cuda.device_count()
49
+ if n_gpu_use > 0 and n_gpu == 0:
50
+ self.logger.warning(
51
+ "Warning: There\'s no GPU available on this machine," "training will be performed on CPU.")
52
+ n_gpu_use = 0
53
+ if n_gpu_use > n_gpu:
54
+ self.logger.warning(
55
+ "Warning: The number of GPU\'s configured to use is {}, but only {} are available " "on this machine.".format(
56
+ n_gpu_use, n_gpu))
57
+ n_gpu_use = n_gpu
58
+ device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
59
+ list_ids = list(range(n_gpu_use))
60
+ return device, list_ids
61
+
62
+ def _load_checkpoint(self, load_path):
63
+ load_path = str(load_path)
64
+ self.logger.info("Loading checkpoint: {} ...".format(load_path))
65
+ checkpoint = torch.load(load_path)
66
+ self.model.load_state_dict(checkpoint)
67
+
68
+
69
+ class Tester(BaseTester):
70
+ def __init__(self, model, criterion, metric_ftns, args, test_dataloader):
71
+ super(Tester, self).__init__(model, criterion, metric_ftns, args)
72
+ self.test_dataloader = test_dataloader
73
+
74
+ def test(self):
75
+ self.logger.info('Start to evaluate in the test set.')
76
+ self.model.eval()
77
+ log = dict()
78
+ with torch.no_grad():
79
+ test_gts, test_res = [], []
80
+ for batch_idx, (images_id, images, reports_ids, reports_masks, align_ids, align_masks) in enumerate(self.test_dataloader):
81
+ images, reports_ids, reports_masks, align_ids, align_masks = images.to(self.device), reports_ids.to(self.device), \
82
+ reports_masks.to(self.device), align_ids.to(self.device), align_masks.to(self.device)
83
+ output = self.model(reports_ids, align_ids, align_masks, images, mode='sample')
84
+ reports = self.model.tokenizer.decode_batch(output.cpu().numpy())
85
+ ground_truths = self.model.tokenizer.decode_batch(reports_ids[:, 1:].cpu().numpy())
86
+ test_res.extend(reports)
87
+ test_gts.extend(ground_truths)
88
+
89
+ test_met = self.metric_ftns({i: [gt] for i, gt in enumerate(test_gts)},
90
+ {i: [re] for i, re in enumerate(test_res)})
91
+ log.update(**{'test_' + k: v for k, v in test_met.items()})
92
+ print(log)
93
+
94
+ test_res, test_gts = pd.DataFrame(test_res), pd.DataFrame(test_gts)
95
+ test_res.to_csv(os.path.join(self.save_dir, "res.csv"), index=False, header=False)
96
+ test_gts.to_csv(os.path.join(self.save_dir, "gts.csv"), index=False, header=False)
97
+
98
+ return log
99
+
100
+ def plot(self):
101
+ assert self.args.batch_size == 1 and self.args.beam_size == 1
102
+ self.logger.info('Start to plot attention weights in the test set.')
103
+ os.makedirs(os.path.join(self.save_dir, "attentions"), exist_ok=True)
104
+ os.makedirs(os.path.join(self.save_dir, "attentions_entities"), exist_ok=True)
105
+ ner = spacy.load("en_core_sci_sm")
106
+ mean = torch.tensor((0.485, 0.456, 0.406))
107
+ std = torch.tensor((0.229, 0.224, 0.225))
108
+ mean = mean[:, None, None]
109
+ std = std[:, None, None]
110
+
111
+ self.model.eval()
112
+ with torch.no_grad():
113
+ for batch_idx, (images_id, images, reports_ids, reports_masks) in tqdm(enumerate(self.test_dataloader)):
114
+ images, reports_ids, reports_masks = images.to(self.device), reports_ids.to(
115
+ self.device), reports_masks.to(self.device)
116
+ output, _ = self.model(images, mode='sample')
117
+ image = torch.clamp((images[0].cpu() * std + mean) * 255, 0, 255).int().cpu().numpy()
118
+ report = self.model.tokenizer.decode_batch(output.cpu().numpy())[0].split()
119
+
120
+ char2word = [idx for word_idx, word in enumerate(report) for idx in [word_idx] * (len(word) + 1)][:-1]
121
+
122
+ attention_weights = self.model.encoder_decoder.attention_weights[:-1]
123
+ assert len(attention_weights) == len(report)
124
+ for word_idx, (attns, word) in enumerate(zip(attention_weights, report)):
125
+ for layer_idx, attn in enumerate(attns):
126
+ os.makedirs(os.path.join(self.save_dir, "attentions", "{:04d}".format(batch_idx),
127
+ "layer_{}".format(layer_idx)), exist_ok=True)
128
+
129
+ heatmap = generate_heatmap(image, attn.mean(1).squeeze())
130
+ cv2.imwrite(os.path.join(self.save_dir, "attentions", "{:04d}".format(batch_idx),
131
+ "layer_{}".format(layer_idx), "{:04d}_{}.png".format(word_idx, word)),
132
+ heatmap)
133
+
134
+ for ne_idx, ne in enumerate(ner(" ".join(report)).ents):
135
+ for layer_idx in range(len(attention_weights[0])):
136
+ os.makedirs(os.path.join(self.save_dir, "attentions_entities", "{:04d}".format(batch_idx),
137
+ "layer_{}".format(layer_idx)), exist_ok=True)
138
+ attn = [attns[layer_idx] for attns in
139
+ attention_weights[char2word[ne.start_char]:char2word[ne.end_char] + 1]]
140
+ attn = np.concatenate(attn, axis=2)
141
+ heatmap = generate_heatmap(image, attn.mean(1).mean(1).squeeze())
142
+ cv2.imwrite(os.path.join(self.save_dir, "attentions_entities", "{:04d}".format(batch_idx),
143
+ "layer_{}".format(layer_idx), "{:04d}_{}.png".format(ne_idx, ne)),
144
+ heatmap)
modules/tokenizers.py ADDED
@@ -0,0 +1,95 @@
+ import json
+ import re
+ from collections import Counter
+
+
+ class Tokenizer(object):
+     def __init__(self, args):
+         self.ann_path = args.ann_path
+         self.threshold = args.threshold
+         self.dataset_name = args.dataset_name
+         if self.dataset_name == 'iu_xray':
+             self.clean_report = self.clean_report_iu_xray
+         else:
+             self.clean_report = self.clean_report_mimic_cxr
+         self.ann = json.loads(open(self.ann_path, 'r').read())
+         self.token2idx, self.idx2token = self.create_vocabulary()
+
+     def create_vocabulary(self):
+         total_tokens = []
+
+         for example in self.ann['train']:
+             tokens = self.clean_report(example['report']).split()
+             for token in tokens:
+                 total_tokens.append(token)
+
+         counter = Counter(total_tokens)
+         vocab = [k for k, v in counter.items() if v >= self.threshold] + ['<unk>']
+         vocab.sort()
+         token2idx, idx2token = {}, {}
+         for idx, token in enumerate(vocab):
+             token2idx[token] = idx + 1
+             idx2token[idx + 1] = token
+         return token2idx, idx2token
+
+     def clean_report_iu_xray(self, report):
+         report_cleaner = lambda t: t.replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '') \
+             .replace('. 2. ', '. ').replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ') \
+             .replace(' 2. ', '. ').replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \
+             .strip().lower().split('. ')
+         sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '').
+                                         replace('\\', '').replace("'", '').strip().lower())
+         tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []]
+         report = ' . '.join(tokens) + ' .'
+         return report
+
+     def clean_report_mimic_cxr(self, report):
+         report_cleaner = lambda t: t.replace('\n', ' ').replace('__', '_').replace('__', '_').replace('__', '_') \
+             .replace('__', '_').replace('__', '_').replace('__', '_').replace('__', '_').replace('  ', ' ') \
+             .replace('  ', ' ').replace('  ', ' ').replace('  ', ' ').replace('  ', ' ').replace('  ', ' ') \
+             .replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.') \
+             .replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '').replace('. 2. ', '. ') \
+             .replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ').replace(' 2. ', '. ') \
+             .replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \
+             .strip().lower().split('. ')
+         sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '')
+                                         .replace('\\', '').replace("'", '').strip().lower())
+         tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []]
+         report = ' . '.join(tokens) + ' .'
+         return report
+
+     def get_token_by_id(self, id):
+         return self.idx2token[id]
+
+     def get_id_by_token(self, token):
+         if token not in self.token2idx:
+             return self.token2idx['<unk>']
+         return self.token2idx[token]
+
+     def get_vocab_size(self):
+         return len(self.token2idx)
+
+     def __call__(self, report):
+         tokens = self.clean_report(report).split()
+         ids = []
+         for token in tokens:
+             ids.append(self.get_id_by_token(token))
+         ids = [0] + ids + [0]
+         return ids
+
+     def decode(self, ids):
+         txt = ''
+         for i, idx in enumerate(ids):
+             if idx > 0:
+                 if i >= 1:
+                     txt += ' '
+                 txt += self.idx2token[idx]
+             else:
+                 break
+         return txt
+
+     def decode_batch(self, ids_batch):
+         out = []
+         for ids in ids_batch:
+             out.append(self.decode(ids))
+         return out
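Note: the tokenizer pads both ends of the id sequence with 0, which decode treats as the stop signal. A toy round trip (vocabulary contents depend on the training annotations):

tokenizer = Tokenizer(args)
ids = tokenizer('The heart is normal in size. 2. No acute findings.')
# -> [0, id('the'), id('heart'), ..., 0]; words below the frequency threshold map to '<unk>'
print(tokenizer.decode(ids[1:]))  # skip the leading 0; decoding stops at the trailing 0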
modules/trainer.py ADDED
@@ -0,0 +1,255 @@
+ import os
+ from abc import abstractmethod
+ import json
+ import time
+ import torch
+ import pandas as pd
+ from numpy import inf
+
+
+ class BaseTrainer(object):
+     def __init__(self, model, criterion, metric_ftns, optimizer, args):
+         self.args = args
+
+         # setup GPU device if available, move model into configured device
+         self.device, device_ids = self._prepare_device(args.n_gpu)
+         self.model = model.to(self.device)
+         if len(device_ids) > 1:
+             self.model = torch.nn.DataParallel(model, device_ids=device_ids)
+
+         self.criterion = criterion
+         self.metric_ftns = metric_ftns
+         self.optimizer = optimizer
+
+         self.epochs = self.args.epochs
+         self.save_period = self.args.save_period
+
+         self.mnt_mode = args.monitor_mode
+         self.mnt_metric = 'val_' + args.monitor_metric
+         self.mnt_metric_test = 'test_' + args.monitor_metric
+         assert self.mnt_mode in ['min', 'max']
+
+         self.mnt_best = inf if self.mnt_mode == 'min' else -inf
+         self.early_stop = getattr(self.args, 'early_stop', inf)
+
+         self.start_epoch = 1
+         self.checkpoint_dir = args.save_dir
+
+         if not os.path.exists(self.checkpoint_dir):
+             os.makedirs(self.checkpoint_dir)
+
+         if args.resume is not None:
+             self._resume_checkpoint(args.resume)
+
+         self.best_recorder = {'val': {self.mnt_metric: self.mnt_best},
+                               'test': {self.mnt_metric_test: self.mnt_best}}
+
+     @abstractmethod
+     def _train_epoch(self, epoch):
+         raise NotImplementedError
+
+     def train(self):
+         not_improved_count = 0
+         for epoch in range(self.start_epoch, self.epochs + 1):
+             result = self._train_epoch(epoch)
+
+             # save logged informations into log dict
+             log = {'epoch': epoch}
+             log.update(result)
+             self._record_best(log)
+
+             # print logged informations to the screen
+             for key, value in log.items():
+                 print('\t{:15s}: {}'.format(str(key), value))
+
+             # evaluate model performance according to configured metric, save best checkpoint as model_best
+             best = False
+             if self.mnt_mode != 'off':
+                 try:
+                     # check whether model performance improved or not, according to specified metric (mnt_metric)
+                     improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
+                                (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
+                 except KeyError:
+                     print("Warning: Metric '{}' is not found. Model performance monitoring is disabled.".format(
+                         self.mnt_metric))
+                     self.mnt_mode = 'off'
+                     improved = False
+
+                 if improved:
+                     self.mnt_best = log[self.mnt_metric]
+                     not_improved_count = 0
+                     best = True
+                 else:
+                     not_improved_count += 1
+
+                 if not_improved_count > self.early_stop:
+                     print("Validation performance didn't improve for {} epochs. Training stops.".format(
+                         self.early_stop))
+                     break
+
+             if epoch % self.save_period == 0:
+                 self._save_checkpoint(epoch, save_best=best)
+         self._print_best()
+         self._print_best_to_file()
+
+     def _print_best_to_file(self):
+         crt_time = time.asctime(time.localtime(time.time()))
+         self.best_recorder['val']['time'] = crt_time
+         self.best_recorder['test']['time'] = crt_time
+         self.best_recorder['val']['seed'] = self.args.seed
+         self.best_recorder['test']['seed'] = self.args.seed
+         self.best_recorder['val']['best_model_from'] = 'val'
+         self.best_recorder['test']['best_model_from'] = 'test'
+
+         if not os.path.exists(self.args.record_dir):
+             os.makedirs(self.args.record_dir)
+         record_path = os.path.join(self.args.record_dir, self.args.dataset_name + '.csv')
+         if not os.path.exists(record_path):
+             record_table = pd.DataFrame()
+         else:
+             record_table = pd.read_csv(record_path)
+         # DataFrame.append was removed in pandas 2.0; pd.concat is the equivalent
+         record_table = pd.concat([record_table, pd.DataFrame([self.best_recorder['val']]),
+                                   pd.DataFrame([self.best_recorder['test']])], ignore_index=True)
+         record_table.to_csv(record_path, index=False)
+
+     def _prepare_device(self, n_gpu_use):
+         n_gpu = torch.cuda.device_count()
+         if n_gpu_use > 0 and n_gpu == 0:
+             print("Warning: There's no GPU available on this machine, training will be performed on CPU.")
+             n_gpu_use = 0
+         if n_gpu_use > n_gpu:
+             print(
+                 "Warning: The number of GPUs configured to use is {}, but only {} are available on this machine.".format(
+                     n_gpu_use, n_gpu))
+             n_gpu_use = n_gpu
+         device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
+         list_ids = list(range(n_gpu_use))
+         return device, list_ids
+
+     def _save_checkpoint(self, epoch, save_best=False):
+         state = {
+             'epoch': epoch,
+             'state_dict': self.model.state_dict(),
+             'optimizer': self.optimizer.state_dict(),
+             'monitor_best': self.mnt_best
+         }
+         filename = os.path.join(self.checkpoint_dir, 'current_checkpoint.pth')
+         torch.save(state, filename)
+         print("Saving checkpoint: {} ...".format(filename))
+         if save_best:
+             best_path = os.path.join(self.checkpoint_dir, 'model_best.pth')
+             torch.save(state, best_path)
+             print("Saving current best: model_best.pth ...")
+
+     def _resume_checkpoint(self, resume_path):
+         resume_path = str(resume_path)
+         print("Loading checkpoint: {} ...".format(resume_path))
+         checkpoint = torch.load(resume_path)
+         self.start_epoch = checkpoint['epoch'] + 1
+         self.mnt_best = checkpoint['monitor_best']
+         self.model.load_state_dict(checkpoint['state_dict'])
+         self.optimizer.load_state_dict(checkpoint['optimizer'])
+
+         print("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))
+
+     def _record_best(self, log):
+         improved_val = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.best_recorder['val'][
+             self.mnt_metric]) or \
+                        (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.best_recorder['val'][self.mnt_metric])
+         if improved_val:
+             self.best_recorder['val'].update(log)
+
+         improved_test = (self.mnt_mode == 'min' and log[self.mnt_metric_test] <= self.best_recorder['test'][
+             self.mnt_metric_test]) or \
+                         (self.mnt_mode == 'max' and log[self.mnt_metric_test] >= self.best_recorder['test'][
+                             self.mnt_metric_test])
+         if improved_test:
+             self.best_recorder['test'].update(log)
+
+     def _print_best(self):
+         print('Best results (w.r.t {}) in validation set:'.format(self.args.monitor_metric))
+         for key, value in self.best_recorder['val'].items():
+             print('\t{:15s}: {}'.format(str(key), value))
+
+         print('Best results (w.r.t {}) in test set:'.format(self.args.monitor_metric))
+         for key, value in self.best_recorder['test'].items():
+             print('\t{:15s}: {}'.format(str(key), value))
+
+
+ if not os.path.exists('valreports/'):
+     os.makedirs('valreports/')
+ if not os.path.exists('testreports/'):
+     os.makedirs('testreports/')
+
+
+ class Trainer(BaseTrainer):
+     def __init__(self, model, criterion, metric_ftns, optimizer, args, lr_scheduler, train_dataloader, val_dataloader,
+                  test_dataloader):
+         super(Trainer, self).__init__(model, criterion, metric_ftns, optimizer, args)
+         self.lr_scheduler = lr_scheduler
+         self.train_dataloader = train_dataloader
+         self.val_dataloader = val_dataloader
+         self.test_dataloader = test_dataloader
+
+     def _train_epoch(self, epoch):
+         train_loss = 0
+         self.model.train()
+         for batch_idx, (images_id, images, reports_ids, reports_masks) in enumerate(self.train_dataloader):
+             images, reports_ids, reports_masks = images.to(self.device), reports_ids.to(self.device), reports_masks.to(
+                 self.device)
+             output = self.model(images, reports_ids, mode='train')
+             loss = self.criterion(output, reports_ids, reports_masks)
+             train_loss += loss.item()
+             self.optimizer.zero_grad()
+             loss.backward()
+             torch.nn.utils.clip_grad_value_(self.model.parameters(), 0.1)
+             self.optimizer.step()
+         log = {'train_loss': train_loss / len(self.train_dataloader)}
+
+         self.model.eval()
+         with torch.no_grad():
+             result_report_val = []
+             val_gts, val_res = [], []
+             for batch_idx, (images_id, images, reports_ids, reports_masks) in enumerate(self.val_dataloader):
+                 images, reports_ids, reports_masks = images.to(self.device), reports_ids.to(
+                     self.device), reports_masks.to(self.device)
+                 output = self.model(images, mode='sample')
+                 reports = self.model.tokenizer.decode_batch(output.cpu().numpy())
+                 for i in range(reports_ids.shape[0]):
+                     temp1 = {'reports_ids': images_id[i], 'reports': reports[i]}
+                     result_report_val.append(temp1)
+                 ground_truths = self.model.tokenizer.decode_batch(reports_ids[:, 1:].cpu().numpy())
+                 val_res.extend(reports)
+                 val_gts.extend(ground_truths)
+             val_met = self.metric_ftns({i: [gt] for i, gt in enumerate(val_gts)},
+                                        {i: [re] for i, re in enumerate(val_res)})
+             log.update(**{'val_' + k: v for k, v in val_met.items()})
+             resFileval = 'valreports/mixed-' + str(epoch) + '.json'
+             json.dump(result_report_val, open(resFileval, 'w'))
+
+         self.model.eval()
+         with torch.no_grad():
+             result_report_test = []
+             test_gts, test_res = [], []
+             for batch_idx, (images_id, images, reports_ids, reports_masks) in enumerate(self.test_dataloader):
+                 images, reports_ids, reports_masks = images.to(self.device), reports_ids.to(
+                     self.device), reports_masks.to(self.device)
+                 output = self.model(images, mode='sample')
+                 reports = self.model.tokenizer.decode_batch(output.cpu().numpy())
+                 for i in range(reports_ids.shape[0]):
+                     temp = {'reports_ids': images_id[i], 'reports': reports[i]}
+                     result_report_test.append(temp)
+                 ground_truths = self.model.tokenizer.decode_batch(reports_ids[:, 1:].cpu().numpy())
+                 test_res.extend(reports)
+                 test_gts.extend(ground_truths)
+             test_met = self.metric_ftns({i: [gt] for i, gt in enumerate(test_gts)},
+                                         {i: [re] for i, re in enumerate(test_res)})
+             log.update(**{'test_' + k: v for k, v in test_met.items()})
+             resFiletest = 'testreports/mixed-' + str(epoch) + '.json'
+             json.dump(result_report_test, open(resFiletest, 'w'))
+         self.lr_scheduler.step()
+
+         return log
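Note: a minimal sketch of how these pieces are typically wired together; the dataloader class name and argument set are assumptions inferred from the module names in this upload, not confirmed by the source:

from modules.tokenizers import Tokenizer
from modules.metrics import compute_scores
from modules.loss import compute_loss
from modules.optimizers import build_optimizer, build_lr_scheduler
from modules.trainer import Trainer

tokenizer = Tokenizer(args)
train_dl, val_dl, test_dl = ...   # dataloaders built from args + tokenizer (see modules/dataloaders.py)
model = ...                       # the report-generation model (see models/)
optimizer = build_optimizer(args, model)
lr_scheduler = build_lr_scheduler(args, optimizer)
Trainer(model, compute_loss, compute_scores, optimizer, args,
        lr_scheduler, train_dl, val_dl, test_dl).train()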
modules/utils.py ADDED
@@ -0,0 +1,55 @@
+ import torch
+
+
+ def penalty_builder(penalty_config):
+     if penalty_config == '':
+         return lambda x, y: y
+     pen_type, alpha = penalty_config.split('_')
+     alpha = float(alpha)
+     if pen_type == 'wu':
+         return lambda x, y: length_wu(x, y, alpha)
+     if pen_type == 'avg':
+         return lambda x, y: length_average(x, y, alpha)
+
+
+ def length_wu(length, logprobs, alpha=0.):
+     """
+     NMT length re-ranking score from
+     "Google's Neural Machine Translation System" :cite:`wu2016google`.
+     """
+     modifier = (((5 + length) ** alpha) /
+                 ((5 + 1) ** alpha))
+     return logprobs / modifier
+
+
+ def length_average(length, logprobs, alpha=0.):
+     """
+     Returns the average probability of tokens in a sequence.
+     """
+     return logprobs / length
+
+
+ def split_tensors(n, x):
+     if torch.is_tensor(x):
+         assert x.shape[0] % n == 0
+         x = x.reshape(x.shape[0] // n, n, *x.shape[1:]).unbind(1)
+     elif type(x) is list or type(x) is tuple:
+         x = [split_tensors(n, _) for _ in x]
+     elif x is None:
+         x = [None] * n
+     return x
+
+
+ def repeat_tensors(n, x):
+     """
+     For a tensor of size Bx..., we repeat it n times, and make it Bnx...
+     For collections, do nested repeat
+     """
+     if torch.is_tensor(x):
+         x = x.unsqueeze(1)  # Bx1x...
+         x = x.expand(-1, n, *([-1] * len(x.shape[2:])))  # Bxnx...
+         x = x.reshape(x.shape[0] * n, *x.shape[2:])  # Bnx...
+     elif type(x) is list or type(x) is tuple:
+         x = [repeat_tensors(n, _) for _ in x]
+     return x
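Note: modules/tester.py imports generate_heatmap from this module, but the uploaded utils.py does not define it. A minimal sketch of a compatible implementation (an assumption, not the author's code), overlaying resized attention weights on the input image:

import cv2
import numpy as np

def generate_heatmap(image, weights):
    # image: CxHxW int array; weights: 1-D attention over a square patch grid
    image = image.transpose(1, 2, 0)  # HxWxC for OpenCV
    height, width, _ = image.shape
    weights = weights.reshape(int(weights.shape[0] ** 0.5), -1)
    weights = weights / weights.max()
    heatmap = cv2.applyColorMap(
        np.uint8(255 * cv2.resize(weights, (width, height))), cv2.COLORMAP_JET)
    return (heatmap * 0.5 + image * 0.5).astype(np.uint8)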
modules/visual_extractor.py ADDED
@@ -0,0 +1,53 @@
+ import os
+
+ os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
+
+ from medclip import MedCLIPModel, MedCLIPVisionModelViT
+ from medclip import MedCLIPProcessor
+ from PIL import Image
+ import torch
+ import torch.nn as nn
+ import torchvision.models as models
+ import torch.nn.functional as F
+
+
+ class VisualExtractor(nn.Module):
+     # prepare for the demo image and text
+     def __init__(self, args):
+         super(VisualExtractor, self).__init__()
+         self.model = MedCLIPModel(vision_cls=MedCLIPVisionModelViT)
+         self.model.from_pretrained()
+         self.model.cuda()
+         self.processor = MedCLIPProcessor()
+         with torch.no_grad():
+             self.prompt = torch.load('prompt/prompt.pth')
+
+     def forward(self, images):
+         # encode each image in the batch with MedCLIP
+         a = []
+         for i in images:
+             inputs = self.processor(text="lungs", images=i, return_tensors="pt", padding=True)
+             outputs = self.model(**inputs)
+             feats = outputs['img_embeds']
+             a.append(feats)
+         batch_feats = torch.stack(a, dim=0)
+
+         # attend over the learned prompt bank, using the image embedding as query
+         ha = []
+         for i in range(batch_feats.shape[0]):
+             b = batch_feats[i].unsqueeze(1)
+             b = b.repeat(self.prompt.shape[0], 1, 1).transpose(-2, -1)
+             c_t = torch.bmm(self.prompt, b)
+             c_t = c_t.float()
+             alpha = F.softmax(c_t, dim=0)  # over the prompt bank; explicit dim matches the legacy default for 3-D input
+             aa = alpha * self.prompt
+             sum_a = aa.sum(axis=0)
+             ha.append(sum_a)
+         featsem = torch.stack(ha, dim=0)
+
+         # concatenate prompt-attended semantic features with the raw image embedding
+         feats = torch.cat((featsem, batch_feats), dim=2)
+
+         patch_feats = feats.repeat(1, 49, 1)
+         batch_feats1 = feats.squeeze(1)
+         avg_feats = batch_feats1
+
+         return patch_feats, avg_feats
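Note: a shape walk-through, under the assumption that prompt.pth holds a (P, 1, 512) prompt bank and that MedCLIP returns 512-d image embeddings (both inferred from the code, not confirmed):

# hypothetical shapes only
# batch_feats: (B, 1, 512)   MedCLIP image embeddings
# featsem:     (B, 1, 512)   prompt-attended semantic features
# feats:       (B, 1, 1024)  concatenation of the two
patch_feats, avg_feats = extractor(images)  # (B, 49, 1024), (B, 1024)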
prompt/prompt.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4b03692f5ba61e9d50d10556cdbb724ed6249668873bad099bc6548af618a7d0
+ size 20480747
pycocoevalcap/README.md ADDED
@@ -0,0 +1,23 @@
+ Microsoft COCO Caption Evaluation Tools <br />
+ ---
+
+ Modified the code to work with Python 3. <br />
+
+ ### Requirements
+ * Python 3.x
+ * Java 1.8
+ * pycocotools
+
+ ---
+
+ ### Tested on
+ * Windows 10, Python 3.5.
+
+ ---
+ ### To fix Windows JVM memory error: <br />
+ Add the following in System Variables <br />
+ &nbsp;&nbsp;&nbsp;&nbsp;Variable name : _JAVA_OPTIONS <br />
+ &nbsp;&nbsp;&nbsp;&nbsp;Variable value : -Xmx1024M <br />
+
+ ---
+ Original code : https://github.com/tylin/coco-caption <br />
pycocoevalcap/__init__.py ADDED
@@ -0,0 +1 @@
+ __author__ = 'tylin'
pycocoevalcap/bleu/LICENSE ADDED
@@ -0,0 +1,19 @@
+ Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in
+ all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ THE SOFTWARE.
pycocoevalcap/bleu/__init__.py ADDED
@@ -0,0 +1 @@
+ __author__ = 'tylin'
pycocoevalcap/bleu/bleu.py ADDED
@@ -0,0 +1,57 @@
+ #!/usr/bin/env python
+ #
+ # File Name : bleu.py
+ #
+ # Description : Wrapper for BLEU scorer.
+ #
+ # Creation Date : 06-01-2015
+ # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT
+ # Authors : Hao Fang <hfang@uw.edu> and Tsung-Yi Lin <tl483@cornell.edu>
+
+ # Last modified : Wed 22 May 2019 08:10:00 PM EDT
+ # By Sabarish Sivanath
+ # To support Python 3
+
+ from .bleu_scorer import BleuScorer
+
+
+ class Bleu:
+     def __init__(self, n=4):
+         # default: compute BLEU score up to 4-grams
+         self._n = n
+         self._hypo_for_image = {}
+         self.ref_for_image = {}
+
+     def compute_score(self, gts, res, score_option='closest', verbose=1):
+         '''
+         Inputs:
+             gts - ground truths
+             res - predictions
+             score_option - {shortest, closest, average}
+             verbose - 1 or 0
+         Outputs:
+             BLEU scores
+         '''
+         assert(gts.keys() == res.keys())
+         imgIds = gts.keys()
+
+         bleu_scorer = BleuScorer(n=self._n)
+         for id in imgIds:
+             hypo = res[id]
+             ref = gts[id]
+
+             # Sanity check.
+             assert(type(hypo) is list)
+             assert(len(hypo) == 1)
+             assert(type(ref) is list)
+             #assert(len(ref) >= 1)
+
+             bleu_scorer += (hypo[0], ref)
+
+         score, scores = bleu_scorer.compute_score(option=score_option, verbose=verbose)
+
+         # return (bleu, bleu_info)
+         return score, scores
+
+     def method(self):
+         return "Bleu"
pycocoevalcap/bleu/bleu_scorer.py ADDED
@@ -0,0 +1,268 @@
+ # bleu_scorer.py
+ # David Chiang <chiang@isi.edu>
+
+ # Copyright (c) 2004-2006 University of Maryland. All rights
+ # reserved. Do not redistribute without permission from the
+ # author. Not for commercial use.
+
+ # Modified by:
+ # Hao Fang <hfang@uw.edu>
+ # Tsung-Yi Lin <tl483@cornell.edu>
+
+ # Last modified : Wed 22 May 2019 08:10:00 PM EDT
+ # By Sabarish Sivanath
+ # To support Python 3
+
+ '''Provides:
+ cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
+ cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
+ '''
+
+ import copy
+ import sys, math, re
+ from collections import defaultdict
+
+
+ def precook(s, n=4, out=False):
+     """Takes a string as input and returns an object that can be given to
+     either cook_refs or cook_test. This is optional: cook_refs and cook_test
+     can take string arguments as well."""
+     words = s.split()
+     counts = defaultdict(int)
+     for k in range(1, n + 1):
+         for i in range(len(words) - k + 1):
+             ngram = tuple(words[i:i + k])
+             counts[ngram] += 1
+     return (len(words), counts)
+
+
+ def cook_refs(refs, eff=None, n=4):  ## lhuang: oracle will call with "average"
+     '''Takes a list of reference sentences for a single segment
+     and returns an object that encapsulates everything that BLEU
+     needs to know about them.'''
+
+     reflen = []
+     maxcounts = {}
+     for ref in refs:
+         rl, counts = precook(ref, n)
+         reflen.append(rl)
+         for (ngram, count) in counts.items():
+             maxcounts[ngram] = max(maxcounts.get(ngram, 0), count)
+
+     # Calculate effective reference sentence length.
+     if eff == "shortest":
+         reflen = min(reflen)
+     elif eff == "average":
+         reflen = float(sum(reflen)) / len(reflen)
+
+     ## lhuang: N.B.: leave reflen computation to the very end!!
+
+     ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design)
+
+     return (reflen, maxcounts)
+
+
+ def cook_test(test, refs, eff=None, n=4):
+     '''Takes a test sentence and returns an object that
+     encapsulates everything that BLEU needs to know about it.'''
+
+     reflen = refs[0]
+     refmaxcounts = refs[1]
+
+     testlen, counts = precook(test, n, True)
+
+     result = {}
+
+     # Calculate effective reference sentence length.
+
+     if eff == "closest":
+         result["reflen"] = min((abs(l - testlen), l) for l in reflen)[1]
+     else:  ## i.e., "average" or "shortest" or None
+         result["reflen"] = reflen
+
+     result["testlen"] = testlen
+
+     result["guess"] = [max(0, testlen - k + 1) for k in range(1, n + 1)]
+
+     result['correct'] = [0] * n
+     for (ngram, count) in counts.items():
+         result["correct"][len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count)
+
+     return result
+
+
+ class BleuScorer(object):
+     """Bleu scorer.
+     """
+
+     __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen"
+     # special_reflen is used in oracle (proportional effective ref len for a node).
+
+     def copy(self):
+         ''' copy the refs.'''
+         new = BleuScorer(n=self.n)
+         new.ctest = copy.copy(self.ctest)
+         new.crefs = copy.copy(self.crefs)
+         new._score = None
+         return new
+
+     def __init__(self, test=None, refs=None, n=4, special_reflen=None):
+         ''' singular instance '''
+
+         self.n = n
+         self.crefs = []
+         self.ctest = []
+         self.cook_append(test, refs)
+         self.special_reflen = special_reflen
+
+     def cook_append(self, test, refs):
+         '''called by constructor and __iadd__ to avoid creating new instances.'''
+
+         if refs is not None:
+             self.crefs.append(cook_refs(refs))
+             if test is not None:
+                 cooked_test = cook_test(test, self.crefs[-1])
+                 self.ctest.append(cooked_test)  ## N.B.: -1
+             else:
+                 self.ctest.append(None)  # lens of crefs and ctest have to match
+
+         self._score = None  ## need to recompute
+
+     def ratio(self, option=None):
+         self.compute_score(option=option)
+         return self._ratio
+
+     def score_ratio(self, option=None):
+         '''return (bleu, len_ratio) pair'''
+         return (self.fscore(option=option), self.ratio(option=option))
+
+     def score_ratio_str(self, option=None):
+         return "%.4f (%.2f)" % self.score_ratio(option)
+
+     def reflen(self, option=None):
+         self.compute_score(option=option)
+         return self._reflen
+
+     def testlen(self, option=None):
+         self.compute_score(option=option)
+         return self._testlen
+
+     def retest(self, new_test):
+         if type(new_test) is str:
+             new_test = [new_test]
+         assert len(new_test) == len(self.crefs), new_test
+         self.ctest = []
+         for t, rs in zip(new_test, self.crefs):
+             self.ctest.append(cook_test(t, rs))
+         self._score = None
+
+         return self
+
+     def rescore(self, new_test):
+         ''' replace test(s) with new test(s), and returns the new score.'''
+
+         return self.retest(new_test).compute_score()
+
+     def size(self):
+         assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
+         return len(self.crefs)
+
+     def __iadd__(self, other):
+         '''add an instance (e.g., from another sentence).'''
+
+         if type(other) is tuple:
+             ## avoid creating new BleuScorer instances
+             self.cook_append(other[0], other[1])
+         else:
+             assert self.compatible(other), "incompatible BLEUs."
+             self.ctest.extend(other.ctest)
+             self.crefs.extend(other.crefs)
+             self._score = None  ## need to recompute
+
+         return self
+
+     def compatible(self, other):
+         return isinstance(other, BleuScorer) and self.n == other.n
+
+     def single_reflen(self, option="average"):
+         return self._single_reflen(self.crefs[0][0], option)
+
+     def _single_reflen(self, reflens, option=None, testlen=None):
+
+         if option == "shortest":
+             reflen = min(reflens)
+         elif option == "average":
+             reflen = float(sum(reflens)) / len(reflens)
+         elif option == "closest":
+             reflen = min((abs(l - testlen), l) for l in reflens)[1]
+         else:
+             assert False, "unsupported reflen option %s" % option
+
+         return reflen
+
+     def recompute_score(self, option=None, verbose=0):
+         self._score = None
+         return self.compute_score(option, verbose)
+
+     def compute_score(self, option=None, verbose=0):
+         n = self.n
+         small = 1e-9
+         tiny = 1e-15  ## so that if guess is 0 still return 0
+         bleu_list = [[] for _ in range(n)]
+
+         if self._score is not None:
+             return self._score
+
+         if option is None:
+             option = "average" if len(self.crefs) == 1 else "closest"
+
+         self._testlen = 0
+         self._reflen = 0
+         totalcomps = {'testlen': 0, 'reflen': 0, 'guess': [0] * n, 'correct': [0] * n}
+
+         # for each sentence
+         for comps in self.ctest:
+             testlen = comps['testlen']
+             self._testlen += testlen
+
+             if self.special_reflen is None:  ## need computation
+                 reflen = self._single_reflen(comps['reflen'], option, testlen)
+             else:
+                 reflen = self.special_reflen
+
+             self._reflen += reflen
+
+             for key in ['guess', 'correct']:
+                 for k in range(n):
+                     totalcomps[key][k] += comps[key][k]
+
+             # append per image bleu score
+             bleu = 1.
+             for k in range(n):
+                 bleu *= (float(comps['correct'][k]) + tiny) \
+                         / (float(comps['guess'][k]) + small)
+                 bleu_list[k].append(bleu ** (1. / (k + 1)))
+             ratio = (testlen + tiny) / (reflen + small)  ## N.B.: avoid zero division
+             if ratio < 1:
+                 for k in range(n):
+                     bleu_list[k][-1] *= math.exp(1 - 1 / ratio)
+
+             if verbose > 1:
+                 print(comps, reflen)
+
+         totalcomps['reflen'] = self._reflen
+         totalcomps['testlen'] = self._testlen
+
+         bleus = []
+         bleu = 1.
+         for k in range(n):
+             bleu *= float(totalcomps['correct'][k] + tiny) \
+                     / (totalcomps['guess'][k] + small)
+             bleus.append(bleu ** (1. / (k + 1)))
+         ratio = (self._testlen + tiny) / (self._reflen + small)  ## N.B.: avoid zero division
+         if ratio < 1:
+             for k in range(n):
+                 bleus[k] *= math.exp(1 - 1 / ratio)
+
+         if verbose > 0:
+             print(totalcomps)
+             print("ratio:", ratio)
+
+         self._score = bleus
+         return self._score, bleu_list
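Note: the score assembled here is standard corpus BLEU: for each order k, BLEU_k = BP * (prod_{i<=k} p_i)^(1/k), where p_i = correct_i / guess_i are the clipped n-gram precisions and the brevity penalty BP = exp(1 - r/c) applies only when the candidate length c falls below the effective reference length r. The tiny/small constants merely guard against zero counts and zero division.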
pycocoevalcap/cider/__init__.py ADDED
@@ -0,0 +1 @@
+ __author__ = 'tylin'
pycocoevalcap/cider/cider.py ADDED
@@ -0,0 +1,55 @@
+ # Filename: cider.py
+ #
+ # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric
+ #              by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
+ #
+ # Creation Date: Sun Feb 8 14:16:54 2015
+ #
+ # Authors: Ramakrishna Vedantam <vrama91@vt.edu> and Tsung-Yi Lin <tl483@cornell.edu>
+
+ from .cider_scorer import CiderScorer
+ import pdb
+
+
+ class Cider:
+     """
+     Main Class to compute the CIDEr metric
+     """
+     def __init__(self, test=None, refs=None, n=4, sigma=6.0):
+         # set cider to sum over 1 to 4-grams
+         self._n = n
+         # set the standard deviation parameter for gaussian penalty
+         self._sigma = sigma
+
+     def compute_score(self, gts, res):
+         """
+         Main function to compute CIDEr score
+         :param gts (dict) : dictionary with key <image> and value <tokenized reference sentences>
+                res (dict) : dictionary with key <image> and value <tokenized hypothesis / candidate sentence>
+         :return: cider (float) : computed CIDEr score for the corpus
+         """
+         assert(gts.keys() == res.keys())
+         imgIds = gts.keys()
+
+         cider_scorer = CiderScorer(n=self._n, sigma=self._sigma)
+
+         for id in imgIds:
+             hypo = res[id]
+             ref = gts[id]
+
+             # Sanity check.
+             assert(type(hypo) is list)
+             assert(len(hypo) == 1)
+             assert(type(ref) is list)
+             assert(len(ref) > 0)
+
+             cider_scorer += (hypo[0], ref)
+
+         (score, scores) = cider_scorer.compute_score()
+
+         return score, scores
+
+     def method(self):
+         return "CIDEr"
pycocoevalcap/cider/cider_scorer.py ADDED
@@ -0,0 +1,197 @@
+ #!/usr/bin/env python
+ # Tsung-Yi Lin <tl483@cornell.edu>
+ # Ramakrishna Vedantam <vrama91@vt.edu>
+
+ # Last modified : Wed 22 May 2019 08:10:00 PM EDT
+ # By Sabarish Sivanath
+ # To support Python 3
+
+ import copy
+ from collections import defaultdict
+ import numpy as np
+ import pdb
+ import math
+
+
+ def precook(s, n=4, out=False):
+     """
+     Takes a string as input and returns an object that can be given to
+     either cook_refs or cook_test. This is optional: cook_refs and cook_test
+     can take string arguments as well.
+     :param s: string : sentence to be converted into ngrams
+     :param n: int : number of ngrams for which representation is calculated
+     :return: term frequency vector for occurring ngrams
+     """
+     words = s.split()
+     counts = defaultdict(int)
+     for k in range(1, n + 1):
+         for i in range(len(words) - k + 1):
+             ngram = tuple(words[i:i + k])
+             counts[ngram] += 1
+     return counts
+
+
+ def cook_refs(refs, n=4):  ## lhuang: oracle will call with "average"
+     '''Takes a list of reference sentences for a single segment
+     and returns an object that encapsulates everything that BLEU
+     needs to know about them.
+     :param refs: list of string : reference sentences for some image
+     :param n: int : number of ngrams for which (ngram) representation is calculated
+     :return: result (list of dict)
+     '''
+     return [precook(ref, n) for ref in refs]
+
+
+ def cook_test(test, n=4):
+     '''Takes a test sentence and returns an object that
+     encapsulates everything that BLEU needs to know about it.
+     :param test: list of string : hypothesis sentence for some image
+     :param n: int : number of ngrams for which (ngram) representation is calculated
+     :return: result (dict)
+     '''
+     return precook(test, n, True)
+
+
+ class CiderScorer(object):
+     """CIDEr scorer.
+     """
+
+     def copy(self):
+         ''' copy the refs.'''
+         new = CiderScorer(n=self.n)
+         new.ctest = copy.copy(self.ctest)
+         new.crefs = copy.copy(self.crefs)
+         return new
+
+     def __init__(self, test=None, refs=None, n=4, sigma=6.0):
+         ''' singular instance '''
+         self.n = n
+         self.sigma = sigma
+         self.crefs = []
+         self.ctest = []
+         self.document_frequency = defaultdict(float)
+         self.cook_append(test, refs)
+         self.ref_len = None
+
+     def cook_append(self, test, refs):
+         '''called by constructor and __iadd__ to avoid creating new instances.'''
+
+         if refs is not None:
+             self.crefs.append(cook_refs(refs))
+             if test is not None:
+                 self.ctest.append(cook_test(test))  ## N.B.: -1
+             else:
+                 self.ctest.append(None)  # lens of crefs and ctest have to match
+
+     def size(self):
+         assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
+         return len(self.crefs)
+
+     def __iadd__(self, other):
+         '''add an instance (e.g., from another sentence).'''
+
+         if type(other) is tuple:
+             ## avoid creating new CiderScorer instances
+             self.cook_append(other[0], other[1])
+         else:
+             self.ctest.extend(other.ctest)
+             self.crefs.extend(other.crefs)
+
+         return self
+
+     def compute_doc_freq(self):
+         '''
+         Compute term frequency for reference data.
+         This will be used to compute idf (inverse document frequency later)
+         The term frequency is stored in the object
+         :return: None
+         '''
+         for refs in self.crefs:
+             # refs, k ref captions of one image
+             for ngram in set([ngram for ref in refs for (ngram, count) in ref.items()]):
+                 self.document_frequency[ngram] += 1
+             # maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
+
+     def compute_cider(self):
+         def counts2vec(cnts):
+             """
+             Function maps counts of ngram to vector of tfidf weights.
+             The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights.
+             The n-th entry of array denotes length of n-grams.
+             :param cnts:
+             :return: vec (array of dict), norm (array of float), length (int)
+             """
+             vec = [defaultdict(float) for _ in range(self.n)]
+             length = 0
+             norm = [0.0 for _ in range(self.n)]
+             for (ngram, term_freq) in cnts.items():
+                 # give word count 1 if it doesn't appear in reference corpus
+                 df = np.log(max(1.0, self.document_frequency[ngram]))
+                 # ngram index
+                 n = len(ngram) - 1
+                 # tf (term_freq) * idf (precomputed idf) for n-grams
+                 vec[n][ngram] = float(term_freq) * (self.ref_len - df)
+                 # compute norm for the vector. the norm will be used for computing similarity
+                 norm[n] += pow(vec[n][ngram], 2)
+
+                 if n == 1:
+                     length += term_freq
+             norm = [np.sqrt(n) for n in norm]
+             return vec, norm, length
+
+         def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
+             '''
+             Compute the cosine similarity of two vectors.
+             :param vec_hyp: array of dictionary for vector corresponding to hypothesis
+             :param vec_ref: array of dictionary for vector corresponding to reference
+             :param norm_hyp: array of float for vector corresponding to hypothesis
+             :param norm_ref: array of float for vector corresponding to reference
+             :param length_hyp: int containing length of hypothesis
+             :param length_ref: int containing length of reference
+             :return: array of score for each n-grams cosine similarity
+             '''
+             delta = float(length_hyp - length_ref)
+             # measure cosine similarity
+             val = np.array([0.0 for _ in range(self.n)])
+             for n in range(self.n):
+                 # ngram
+                 for (ngram, count) in vec_hyp[n].items():
+                     # vrama91 : added clipping
+                     val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]
+
+                 if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
+                     val[n] /= (norm_hyp[n] * norm_ref[n])
+
+                 assert(not math.isnan(val[n]))
+                 # vrama91: added a length based gaussian penalty
+                 val[n] *= np.e ** (-(delta ** 2) / (2 * self.sigma ** 2))
+             return val
+
+         # compute log reference length
+         self.ref_len = np.log(float(len(self.crefs)))
+
+         scores = []
+         for test, refs in zip(self.ctest, self.crefs):
+             # compute vector for test captions
+             vec, norm, length = counts2vec(test)
+             # compute vector for ref captions
+             score = np.array([0.0 for _ in range(self.n)])
+             for ref in refs:
+                 vec_ref, norm_ref, length_ref = counts2vec(ref)
+                 score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
+             # change by vrama91 - mean of ngram scores, instead of sum
+             score_avg = np.mean(score)
+             # divide by number of references
+             score_avg /= len(refs)
+             # multiply score by 10
+             score_avg *= 10.0
+             # append score of an image to the score list
+             scores.append(score_avg)
+         return scores
+
+     def compute_score(self, option=None, verbose=0):
+         # compute idf
+         self.compute_doc_freq()
+         # assert to check document frequency
+         assert(len(self.ctest) >= max(self.document_frequency.values()))
+         # compute cider score
+         score = self.compute_cider()
+         return np.mean(np.array(score)), np.array(score)
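Note: in counts2vec above, each n-gram g receives the weight tf(g) * (log |I| - log df(g)), i.e. term frequency times corpus IDF (ref_len holds log |I|, the log number of images). sim() then takes a clipped cosine similarity per n-gram order, damped by the Gaussian length penalty exp(-(c - r)^2 / (2*sigma^2)), and compute_cider averages over orders and references and scales the result by 10, following the CIDEr paper.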
pycocoevalcap/eval.py ADDED
@@ -0,0 +1,74 @@
+ __author__ = 'tylin'
+ from .tokenizer.ptbtokenizer import PTBTokenizer
+ from .bleu.bleu import Bleu
+ from .meteor.meteor import Meteor
+ from .rouge.rouge import Rouge
+ from .cider.cider import Cider
+
+
+ class COCOEvalCap:
+     def __init__(self, coco, cocoRes):
+         self.evalImgs = []
+         self.eval = {}
+         self.imgToEval = {}
+         self.coco = coco
+         self.cocoRes = cocoRes
+         self.params = {'image_id': cocoRes.getImgIds()}
+
+     def evaluate(self):
+         imgIds = self.params['image_id']
+         # imgIds = self.coco.getImgIds()
+         gts = {}
+         res = {}
+         for imgId in imgIds:
+             gts[imgId] = self.coco.imgToAnns[imgId]
+             res[imgId] = self.cocoRes.imgToAnns[imgId]
+
+         # =================================================
+         # Tokenization
+         # =================================================
+         print('tokenization...')
+         tokenizer = PTBTokenizer()
+         gts = tokenizer.tokenize(gts)
+         res = tokenizer.tokenize(res)
+
+         # =================================================
+         # Set up scorers
+         # =================================================
+         print('setting up scorers...')
+         scorers = [
+             (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
+             (Meteor(), "METEOR"),
+             (Rouge(), "ROUGE_L"),
+             (Cider(), "CIDEr")
+         ]
+
+         # =================================================
+         # Compute scores
+         # =================================================
+         eval = {}
+         for scorer, method in scorers:
+             print('computing %s score...' % (scorer.method()))
+             score, scores = scorer.compute_score(gts, res)
+             if type(method) == list:
+                 for sc, scs, m in zip(score, scores, method):
+                     self.setEval(sc, m)
+                     self.setImgToEvalImgs(scs, imgIds, m)
+                     print("%s: %0.3f" % (m, sc))
+             else:
+                 self.setEval(score, method)
+                 self.setImgToEvalImgs(scores, imgIds, method)
+                 print("%s: %0.3f" % (method, score))
+         self.setEvalImgs()
+
+     def setEval(self, score, method):
+         self.eval[method] = score
+
+     def setImgToEvalImgs(self, scores, imgIds, method):
+         for imgId, score in zip(imgIds, scores):
+             if not imgId in self.imgToEval:
+                 self.imgToEval[imgId] = {}
+                 self.imgToEval[imgId]["image_id"] = imgId
+             self.imgToEval[imgId][method] = score
+
+     def setEvalImgs(self):
+         self.evalImgs = [eval for imgId, eval in self.imgToEval.items()]
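Note: COCOEvalCap follows the coco-caption driver pattern; a typical invocation (paths hypothetical; assumes pycocotools' COCO annotation and result objects):

from pycocotools.coco import COCO
from pycocoevalcap.eval import COCOEvalCap

coco = COCO('annotations/captions_val.json')              # hypothetical path
cocoRes = coco.loadRes('results/captions_results.json')   # hypothetical path
cocoEval = COCOEvalCap(coco, cocoRes)
cocoEval.evaluate()
print(cocoEval.eval)  # {'Bleu_1': ..., 'METEOR': ..., 'ROUGE_L': ..., 'CIDEr': ...}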
pycocoevalcap/license.txt ADDED
@@ -0,0 +1,26 @@
+ Copyright (c) 2015, Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright notice, this
+    list of conditions and the following disclaimer.
+ 2. Redistributions in binary form must reproduce the above copyright notice,
+    this list of conditions and the following disclaimer in the documentation
+    and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+ The views and conclusions contained in the software and documentation are those
+ of the authors and should not be interpreted as representing official policies,
+ either expressed or implied, of the FreeBSD Project.
pycocoevalcap/meteor/__init__.py ADDED
@@ -0,0 +1 @@
+ from .meteor import *
pycocoevalcap/meteor/meteor-1.5.jar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1e57b4c72c0830ebe68558f1c799a624e96cbc1b6045c9f6330e26dcff6eafc2
+ size 6318693
pycocoevalcap/meteor/meteor.py ADDED
@@ -0,0 +1,88 @@
+ #!/usr/bin/env python
+
+ # Python wrapper for METEOR implementation, by Xinlei Chen
+ # Acknowledge Michael Denkowski for the generous discussion and help
+
+ # Last modified : Wed 22 May 2019 08:10:00 PM EDT
+ # By Sabarish Sivanath
+ # To support Python 3
+
+ import os
+ import sys
+ import subprocess
+ import threading
+
+ # Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed.
+ METEOR_JAR = 'meteor-1.5.jar'
+
+
+ class Meteor:
+
+     def __init__(self):
+         self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR,
+                            '-', '-', '-stdio', '-l', 'en', '-norm']
+         self.meteor_p = subprocess.Popen(self.meteor_cmd,
+                                          cwd=os.path.dirname(os.path.abspath(__file__)),
+                                          stdin=subprocess.PIPE,
+                                          stdout=subprocess.PIPE,
+                                          stderr=subprocess.PIPE,
+                                          universal_newlines=True,
+                                          bufsize=1)
+         # Used to guarantee thread safety
+         self.lock = threading.Lock()
+
+     def compute_score(self, gts, res):
+         assert (gts.keys() == res.keys())
+         imgIds = gts.keys()
+         scores = []
+
+         eval_line = 'EVAL'
+         self.lock.acquire()
+         for i in imgIds:
+             assert (len(res[i]) == 1)
+             stat = self._stat(res[i][0], gts[i])
+             eval_line += ' ||| {}'.format(stat)
+
+         self.meteor_p.stdin.write('{}\n'.format(eval_line))
+         for i in range(0, len(imgIds)):
+             scores.append(float(self.meteor_p.stdout.readline().strip()))
+         score = float(self.meteor_p.stdout.readline().strip())
+         self.lock.release()
+
+         return score, scores
+
+     def method(self):
+         return "METEOR"
+
+     def _stat(self, hypothesis_str, reference_list):
+         # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
+         hypothesis_str = hypothesis_str.replace('|||', '').replace('  ', ' ')
+         score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
+         self.meteor_p.stdin.write('{}\n'.format(score_line))
+         return self.meteor_p.stdout.readline().strip()
+
+     def _score(self, hypothesis_str, reference_list):
+         self.lock.acquire()
+         # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
+         hypothesis_str = hypothesis_str.replace('|||', '').replace('  ', ' ')
+         score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
+         self.meteor_p.stdin.write('{}\n'.format(score_line))
+         stats = self.meteor_p.stdout.readline().strip()
+         eval_line = 'EVAL ||| {}'.format(stats)
+         # EVAL ||| stats
+         self.meteor_p.stdin.write('{}\n'.format(eval_line))
+         score = float(self.meteor_p.stdout.readline().strip())
+         # bug fix: there are two values returned by the jar file, one average, and one all, so do it twice
+         # thanks to Andrej for pointing this out
+         score = float(self.meteor_p.stdout.readline().strip())
+         self.lock.release()
+         return score
+
+     def __del__(self):
+         self.lock.acquire()
+         self.meteor_p.stdin.close()
+         self.meteor_p.kill()
+         self.meteor_p.wait()
+         self.lock.release()
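Note: the wrapper drives the METEOR jar over stdin/stdout in its '-stdio' mode. Each _stat call writes one line of the form 'SCORE ||| ref1 ||| ... ||| hypothesis' and reads back a statistics line; compute_score then sends a single 'EVAL ||| stats ||| stats ...' line and reads one segment score per image followed by the final corpus score. That is why a lock guards the pipe: interleaved writers would corrupt the line-oriented protocol.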
pycocoevalcap/rouge/__init__.py ADDED
@@ -0,0 +1 @@
+ from .rouge import *
pycocoevalcap/rouge/rouge.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
+ #
+ # File Name : rouge.py
+ #
+ # Description : Computes ROUGE-L metric as described by Lin and Hovy (2004)
+ #
+ # Creation Date : 2015-01-07 06:03
+ # Author : Ramakrishna Vedantam <vrama91@vt.edu>
+
+ import numpy as np
+
+ def my_lcs(string, sub):
+     """
+     Calculates longest common subsequence for a pair of tokenized strings
+     :param string: list of str : tokens from a string split using whitespace
+     :param sub: list of str : shorter string, also split using whitespace
+     :returns: length (int): length of the longest common subsequence between the two strings
+
+     Note: my_lcs only gives the length of the longest common subsequence, not the actual LCS
+     """
+     if len(string) < len(sub):
+         sub, string = string, sub
+
+     # dynamic-programming table of LCS lengths for all prefix pairs
+     lengths = [[0 for i in range(0, len(sub) + 1)] for j in range(0, len(string) + 1)]
+
+     for j in range(1, len(sub) + 1):
+         for i in range(1, len(string) + 1):
+             if string[i - 1] == sub[j - 1]:
+                 lengths[i][j] = lengths[i - 1][j - 1] + 1
+             else:
+                 lengths[i][j] = max(lengths[i - 1][j], lengths[i][j - 1])
+
+     return lengths[len(string)][len(sub)]
+
+ class Rouge():
+     '''
+     Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set
+     '''
+     def __init__(self):
+         # vrama91: updated the value below based on discussion with Hovy
+         self.beta = 1.2
+
+     def calc_score(self, candidate, refs):
+         """
+         Compute ROUGE-L score given one candidate and references for an image
+         :param candidate: list of str : single-element list holding the candidate sentence to be evaluated
+         :param refs: list of str : COCO reference sentences for the particular image to be evaluated
+         :returns score: float (ROUGE-L score for the candidate evaluated against references)
+         """
+         assert len(candidate) == 1
+         assert len(refs) > 0
+         prec = []
+         rec = []
+
+         # split into tokens
+         token_c = candidate[0].split(" ")
+
+         for reference in refs:
+             # split into tokens
+             token_r = reference.split(" ")
+             # compute the longest common subsequence
+             lcs = my_lcs(token_r, token_c)
+             prec.append(lcs / float(len(token_c)))
+             rec.append(lcs / float(len(token_r)))
+
+         prec_max = max(prec)
+         rec_max = max(rec)
+
+         # F-measure with recall weighted self.beta times as heavily as precision
+         if prec_max != 0 and rec_max != 0:
+             score = ((1 + self.beta ** 2) * prec_max * rec_max) / float(rec_max + self.beta ** 2 * prec_max)
+         else:
+             score = 0.0
+         return score
+
+     def compute_score(self, gts, res):
+         """
+         Computes ROUGE-L score given a set of reference and candidate sentences for the dataset
+         Invoked by evaluate_captions.py
+         :param gts: dict : reference sentences with "image name" key and "tokenized sentences" as values
+         :param res: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values
+         :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images)
+         """
+         assert gts.keys() == res.keys()
+         imgIds = gts.keys()
+
+         score = []
+         for id in imgIds:
+             hypo = res[id]
+             ref = gts[id]
+
+             # Sanity check (performed per image rather than only on the last one).
+             assert type(hypo) is list
+             assert len(hypo) == 1
+             assert type(ref) is list
+             assert len(ref) > 0
+
+             score.append(self.calc_score(hypo, ref))
+
+         average_score = np.mean(np.array(score))
+         return average_score, np.array(score)
+
+     def method(self):
+         return "Rouge"
pycocoevalcap/tokenizer/__init__.py ADDED
@@ -0,0 +1 @@
+ __author__ = 'hfang'
pycocoevalcap/tokenizer/ptbtokenizer.py ADDED
@@ -0,0 +1,76 @@
+ #!/usr/bin/env python
+ #
+ # File Name : ptbtokenizer.py
+ #
+ # Description : Do the PTB tokenization and remove punctuation.
+ #
+ # Creation Date : 29-12-2014
+ # Last Modified : Thu Mar 19 09:53:35 2015
+ # Authors : Hao Fang <hfang@uw.edu> and Tsung-Yi Lin <tl483@cornell.edu>
+
+ import os
+ import subprocess
+ import tempfile
+
+ # Last modified : Wed 22 May 2019 08:10:00 PM EDT
+ # By Sabarish Sivanath
+ # To support Python 3
+
+ # path to the Stanford CoreNLP jar
+ STANFORD_CORENLP_3_4_1_JAR = 'stanford-corenlp-3.4.1.jar'
+
+ # punctuation to be removed from the sentences
+ PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-",
+                 ".", "?", "!", ",", ":", "-", "--", "...", ";"]
+
+ class PTBTokenizer:
+     """Python wrapper of Stanford PTBTokenizer"""
+
+     def tokenize(self, captions_for_image):
+         cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR,
+                'edu.stanford.nlp.process.PTBTokenizer',
+                '-preserveLines', '-lowerCase']
+
+         # ======================================================
+         # prepare data for PTB Tokenizer
+         # ======================================================
+         final_tokenized_captions_for_image = {}
+         # one image id per caption, in the same order the sentences are written
+         image_id = [k for k, v in captions_for_image.items() for _ in range(len(v))]
+         sentences = '\n'.join([c['caption'].replace('\n', ' ') for k, v in captions_for_image.items() for c in v])
+
+         # ======================================================
+         # save sentences to temporary file
+         # ======================================================
+         path_to_jar_dirname = os.path.dirname(os.path.abspath(__file__))
+         tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname)
+         tmp_file.write(sentences.encode('utf-8'))
+         tmp_file.close()
+
+         # ======================================================
+         # tokenize sentence
+         # ======================================================
+         cmd.append(os.path.basename(tmp_file.name))
+         p_tokenizer = subprocess.Popen(cmd,
+                                        cwd=path_to_jar_dirname,
+                                        stdout=subprocess.PIPE,
+                                        universal_newlines=True,
+                                        bufsize=1)
+         token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0]
+         lines = token_lines.split('\n')
+         # remove temp file
+         os.remove(tmp_file.name)
+
+         # ======================================================
+         # create dictionary for tokenized captions
+         # ======================================================
+         for k, line in zip(image_id, lines):
+             if k not in final_tokenized_captions_for_image:
+                 final_tokenized_captions_for_image[k] = []
+             tokenized_caption = ' '.join([w for w in line.rstrip().split(' ')
+                                           if w not in PUNCTUATIONS])
+             final_tokenized_captions_for_image[k].append(tokenized_caption)
+
+         return final_tokenized_captions_for_image
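
Finally, a minimal driving sketch for the tokenizer. The nested shape (id -> list of {'caption': ...} dicts) matches what the code above consumes; the id is hypothetical, and `java` plus stanford-corenlp-3.4.1.jar in this directory are assumed:

    from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer

    captions = {'img1': [{'caption': 'The heart is mildly enlarged.'}]}
    tokenizer = PTBTokenizer()
    print(tokenizer.tokenize(captions))
    # -> {'img1': ['the heart is mildly enlarged']}  (lowercased, punctuation stripped)
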