fenglinliu committed
Commit: 6e32a75
Parent(s): c168557
Upload 55 files
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
- .gitattributes +3 -0
- LICENSE +201 -0
- PromptNet.py +114 -0
- app.py +121 -0
- ckpts/few-shot.pth +3 -0
- data/annotation.json +3 -0
- decoder_config/decoder_config.pkl +3 -0
- example_figs/example_fig1.jpg.png +0 -0
- example_figs/example_fig2.jpg.jpg +0 -0
- example_figs/example_fig3.jpg.png +0 -0
- inference.py +110 -0
- models/models.py +125 -0
- models/r2gen.py +63 -0
- modules/att_model.py +319 -0
- modules/att_models.py +120 -0
- modules/caption_model.py +401 -0
- modules/config.pkl +3 -0
- modules/dataloader.py +59 -0
- modules/dataloaders.py +62 -0
- modules/dataset.py +68 -0
- modules/datasets.py +57 -0
- modules/decoder.py +50 -0
- modules/encoder_decoder.py +391 -0
- modules/loss.py +22 -0
- modules/metrics.py +33 -0
- modules/optimizers.py +18 -0
- modules/tester.py +144 -0
- modules/tokenizers.py +95 -0
- modules/trainer.py +255 -0
- modules/utils.py +55 -0
- modules/visual_extractor.py +53 -0
- prompt/prompt.pth +3 -0
- pycocoevalcap/README.md +23 -0
- pycocoevalcap/__init__.py +1 -0
- pycocoevalcap/bleu/LICENSE +19 -0
- pycocoevalcap/bleu/__init__.py +1 -0
- pycocoevalcap/bleu/bleu.py +57 -0
- pycocoevalcap/bleu/bleu_scorer.py +268 -0
- pycocoevalcap/cider/__init__.py +1 -0
- pycocoevalcap/cider/cider.py +55 -0
- pycocoevalcap/cider/cider_scorer.py +197 -0
- pycocoevalcap/eval.py +74 -0
- pycocoevalcap/license.txt +26 -0
- pycocoevalcap/meteor/__init__.py +1 -0
- pycocoevalcap/meteor/meteor-1.5.jar +3 -0
- pycocoevalcap/meteor/meteor.py +88 -0
- pycocoevalcap/rouge/__init__.py +1 -0
- pycocoevalcap/rouge/rouge.py +105 -0
- pycocoevalcap/tokenizer/__init__.py +1 -0
- pycocoevalcap/tokenizer/ptbtokenizer.py +76 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+data/annotation.json filter=lfs diff=lfs merge=lfs -text
+pycocoevalcap/meteor/meteor-1.5.jar filter=lfs diff=lfs merge=lfs -text
+pycocoevalcap/tokenizer/stanford-corenlp-3.4.1.jar filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
PromptNet.py
ADDED
@@ -0,0 +1,114 @@
import torch
import argparse
from modules.dataloader import R2DataLoader
from modules.tokenizers import Tokenizer
from modules.loss import compute_loss
from modules.metrics import compute_scores
from modules.optimizers import build_optimizer, build_lr_scheduler
from models.models import MedCapModel
from modules.trainer import Trainer
import numpy as np

def main():
    parser = argparse.ArgumentParser()

    # Data input settings
    parser.add_argument('--json_path', default='data/mimic_cxr/annotation.json',
                        help='Path to the json file')
    parser.add_argument('--image_dir', default='data/mimic_cxr/images/',
                        help='Directory of images')

    # Dataloader settings
    parser.add_argument('--dataset', default='mimic_cxr', help='dataset for training MedCap')
    parser.add_argument('--bs', type=int, default=16)
    parser.add_argument('--threshold', type=int, default=10, help='the cut-off frequency for the words.')
    parser.add_argument('--num_workers', type=int, default=2, help='the number of workers for the dataloader.')
    parser.add_argument('--max_seq_length', type=int, default=1024, help='the maximum sequence length of the reports.')

    # Trainer settings
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--n_gpu', type=int, default=1, help='the number of gpus to be used.')
    parser.add_argument('--save_dir', type=str, default='results/mimic_cxr/', help='the path to save the models.')
    parser.add_argument('--record_dir', type=str, default='./record_dir/',
                        help='the path to save the results of experiments.')
    parser.add_argument('--log_period', type=int, default=1000, help='the logging interval (in batches).')
    parser.add_argument('--save_period', type=int, default=1)
    parser.add_argument('--monitor_mode', type=str, default='max', choices=['min', 'max'], help='whether to max or min the metric.')
    parser.add_argument('--monitor_metric', type=str, default='BLEU_4', help='the metric to be monitored.')
    parser.add_argument('--early_stop', type=int, default=50, help='the patience of training.')

    # Training related
    parser.add_argument('--noise_inject', default='no', choices=['yes', 'no'])

    # Sample related
    parser.add_argument('--sample_method', type=str, default='greedy', help='the sample method to sample a report.')
    parser.add_argument('--prompt', default='/prompt/prompt.pt')
    parser.add_argument('--prompt_load', default='no', choices=['yes', 'no'])

    # Optimization
    parser.add_argument('--optim', type=str, default='Adam', help='the type of the optimizer.')
    parser.add_argument('--lr_ve', type=float, default=1e-5, help='the learning rate for the visual extractor.')
    parser.add_argument('--lr_ed', type=float, default=5e-4, help='the learning rate for the remaining parameters.')
    parser.add_argument('--weight_decay', type=float, default=5e-5, help='the weight decay.')
    parser.add_argument('--adam_betas', type=tuple, default=(0.9, 0.98), help='the betas for Adam.')
    parser.add_argument('--adam_eps', type=float, default=1e-9, help='the epsilon for Adam.')
    parser.add_argument('--amsgrad', type=bool, default=True, help='whether to use AMSGrad.')
    parser.add_argument('--noamopt_warmup', type=int, default=5000, help='the warmup steps for the Noam optimizer.')
    parser.add_argument('--noamopt_factor', type=int, default=1, help='the factor for the Noam optimizer.')

    # Learning rate scheduler
    parser.add_argument('--lr_scheduler', type=str, default='StepLR', help='the type of the learning rate scheduler.')
    parser.add_argument('--step_size', type=int, default=50, help='the step size of the learning rate scheduler.')
    parser.add_argument('--gamma', type=float, default=0.1, help='the gamma of the learning rate scheduler.')

    # Others
    parser.add_argument('--seed', type=int, default=9153, help='the random seed.')
    parser.add_argument('--resume', type=str, help='whether to resume the training from existing checkpoints.')
    parser.add_argument('--train_mode', default='base', choices=['base', 'fine-tuning'],
                        help='Training mode: base (autoencoding) or fine-tuning (full supervised training or fine-tuning on downstream datasets)')
    parser.add_argument('--F_version', default='v1', choices=['v1', 'v2'])
    parser.add_argument('--clip_update', default='no', choices=['yes', 'no'])

    # Fine-tuning
    parser.add_argument('--random_init', default='yes', choices=['yes', 'no'],
                        help='Whether to load the pre-trained weights for fine-tuning.')
    parser.add_argument('--weight_path', default='path_to_default_weights', type=str,
                        help='Path to the pre-trained model weights.')
    args = parser.parse_args()

    # fix random seeds
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # create tokenizer
    tokenizer = Tokenizer(args)

    # create data loaders
    train_dataloader = R2DataLoader(args, tokenizer, split='train', shuffle=True)
    val_dataloader = R2DataLoader(args, tokenizer, split='val', shuffle=False)
    test_dataloader = R2DataLoader(args, tokenizer, split='test', shuffle=False)

    # get function handles of loss and metrics
    criterion = compute_loss
    metrics = compute_scores
    model = MedCapModel(args, tokenizer)

    if args.train_mode == 'fine-tuning' and args.random_init == 'no':
        # Load weights from the specified path
        checkpoint = torch.load(args.weight_path)
        model.load_state_dict(checkpoint)

    # build optimizer and learning rate scheduler
    optimizer = build_optimizer(args, model)
    lr_scheduler = build_lr_scheduler(args, optimizer)

    # build trainer and start training
    trainer = Trainer(model, criterion, metrics, optimizer, args, lr_scheduler, train_dataloader, val_dataloader, test_dataloader)
    trainer.train()

if __name__ == '__main__':
    main()
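One caveat on the fine-tuning branch above: the checkpoint loading code elsewhere in this commit (app.py, inference.py) expects trainer checkpoints whose weights are wrapped under a 'state_dict' key, while main() above passes the loaded object straight to load_state_dict. A hedged sketch of a more defensive loader; the helper name and the CPU fallback are illustrative, not part of the commit:

import torch

def load_pretrained(model, weight_path, device='cpu'):
    # Trainer checkpoints in this repo may wrap the weights under 'state_dict';
    # fall back to the raw object when they do not.
    checkpoint = torch.load(weight_path, map_location=device)
    state_dict = checkpoint.get('state_dict', checkpoint) if isinstance(checkpoint, dict) else checkpoint
    model.load_state_dict(state_dict)
    return model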
app.py
ADDED
@@ -0,0 +1,121 @@
import gradio as gr
import torch
from PIL import Image
from models.r2gen import R2GenModel
from modules.tokenizers import Tokenizer
import argparse

# Assuming you have a predefined configuration function for the model args
def get_model_args():
    parser = argparse.ArgumentParser()

    # Model loader settings
    parser.add_argument('--load', type=str, default='ckpts/few-shot.pth', help='the path to the model weights.')
    parser.add_argument('--prompt', type=str, default='prompt/prompt.pth', help='the path to the prompt weights.')

    # Data input settings
    parser.add_argument('--image_path', type=str, default='example_figs/example_fig1.jpg', help='the path to the test image.')
    parser.add_argument('--image_dir', type=str, default='data/images/', help='the path to the directory containing the images.')
    parser.add_argument('--ann_path', type=str, default='data/annotation.json', help='the path to the annotation file.')

    # Data loader settings
    parser.add_argument('--dataset_name', type=str, default='mimic_cxr', help='the dataset to be used.')
    parser.add_argument('--max_seq_length', type=int, default=60, help='the maximum sequence length of the reports.')
    parser.add_argument('--threshold', type=int, default=3, help='the cut-off frequency for the words.')
    parser.add_argument('--num_workers', type=int, default=2, help='the number of workers for the dataloader.')
    parser.add_argument('--batch_size', type=int, default=16, help='the number of samples for a batch')

    # Model settings (for visual extractor)
    parser.add_argument('--visual_extractor', type=str, default='resnet101', help='the visual extractor to be used.')
    parser.add_argument('--visual_extractor_pretrained', type=bool, default=True, help='whether to load the pretrained visual extractor')

    # Model settings (for Transformer)
    parser.add_argument('--d_model', type=int, default=512, help='the dimension of Transformer.')
    parser.add_argument('--d_ff', type=int, default=512, help='the dimension of FFN.')
    parser.add_argument('--d_vf', type=int, default=2048, help='the dimension of the patch features.')
    parser.add_argument('--num_heads', type=int, default=8, help='the number of heads in Transformer.')
    parser.add_argument('--num_layers', type=int, default=3, help='the number of layers of Transformer.')
    parser.add_argument('--dropout', type=float, default=0.1, help='the dropout rate of Transformer.')
    parser.add_argument('--logit_layers', type=int, default=1, help='the number of the logit layers.')
    parser.add_argument('--bos_idx', type=int, default=0, help='the index of <bos>.')
    parser.add_argument('--eos_idx', type=int, default=0, help='the index of <eos>.')
    parser.add_argument('--pad_idx', type=int, default=0, help='the index of <pad>.')
    parser.add_argument('--use_bn', type=int, default=0, help='whether to use batch normalization.')
    parser.add_argument('--drop_prob_lm', type=float, default=0.5, help='the dropout rate of the output layer.')
    # for Relational Memory
    parser.add_argument('--rm_num_slots', type=int, default=3, help='the number of memory slots.')
    parser.add_argument('--rm_num_heads', type=int, default=8, help='the number of heads in rm.')
    parser.add_argument('--rm_d_model', type=int, default=512, help='the dimension of rm.')

    # Sample related
    parser.add_argument('--sample_method', type=str, default='beam_search', help='the sample method to sample a report.')
    parser.add_argument('--beam_size', type=int, default=3, help='the beam size when beam searching.')
    parser.add_argument('--temperature', type=float, default=1.0, help='the temperature when sampling.')
    parser.add_argument('--sample_n', type=int, default=1, help='the sample number per image.')
    parser.add_argument('--group_size', type=int, default=1, help='the group size.')
    parser.add_argument('--output_logsoftmax', type=int, default=1, help='whether to output the probabilities.')
    parser.add_argument('--decoding_constraint', type=int, default=0, help='whether to use the decoding constraint.')
    parser.add_argument('--block_trigrams', type=int, default=1, help='whether to block repeated trigrams.')

    # Trainer settings
    parser.add_argument('--n_gpu', type=int, default=1, help='the number of gpus to be used.')
    parser.add_argument('--epochs', type=int, default=100, help='the number of training epochs.')
    parser.add_argument('--save_dir', type=str, default='results/iu_xray', help='the path to save the models.')
    parser.add_argument('--record_dir', type=str, default='records/', help='the path to save the results of experiments')
    parser.add_argument('--save_period', type=int, default=1, help='the saving period.')
    parser.add_argument('--monitor_mode', type=str, default='max', choices=['min', 'max'], help='whether to max or min the metric.')
    parser.add_argument('--monitor_metric', type=str, default='BLEU_4', help='the metric to be monitored.')
    parser.add_argument('--early_stop', type=int, default=50, help='the patience of training.')

    # Optimization
    parser.add_argument('--optim', type=str, default='Adam', help='the type of the optimizer.')
    parser.add_argument('--lr_ve', type=float, default=5e-5, help='the learning rate for the visual extractor.')
    parser.add_argument('--lr_ed', type=float, default=1e-4, help='the learning rate for the remaining parameters.')
    parser.add_argument('--weight_decay', type=float, default=5e-5, help='the weight decay.')
    parser.add_argument('--amsgrad', type=bool, default=True, help='whether to use AMSGrad.')

    # Learning rate scheduler
    parser.add_argument('--lr_scheduler', type=str, default='StepLR', help='the type of the learning rate scheduler.')
    parser.add_argument('--step_size', type=int, default=50, help='the step size of the learning rate scheduler.')
    parser.add_argument('--gamma', type=float, default=0.1, help='the gamma of the learning rate scheduler.')

    # Others
    parser.add_argument('--seed', type=int, default=9233, help='the random seed.')
    parser.add_argument('--resume', type=str, help='whether to resume the training from existing checkpoints.')

    args = parser.parse_args()
    return args

def load_model():
    args = get_model_args()
    tokenizer = Tokenizer(args)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'  # Determine the device dynamically
    model = R2GenModel(args, tokenizer).to(device)
    checkpoint_path = args.load
    # Ensure the state dict is loaded onto the same device as the model
    state_dict = torch.load(checkpoint_path, map_location=device)
    model_state_dict = state_dict['state_dict'] if 'state_dict' in state_dict else state_dict
    model.load_state_dict(model_state_dict)
    model.eval()
    return model, tokenizer

model, tokenizer = load_model()

def generate_report(image):
    image = Image.fromarray(image).convert('RGB')
    with torch.no_grad():
        output = model([image], mode='sample')
    reports = tokenizer.decode_batch(output.cpu().numpy())
    return reports[0]

# Define the Gradio interface
iface = gr.Interface(
    fn=generate_report,
    inputs=gr.inputs.Image(),  # Define input shape as needed
    outputs="text",
    title="PromptNet",
    description="Upload a medical image for thorax disease reporting."
)

if __name__ == "__main__":
    iface.launch()
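A compatibility note on the interface definition above: gr.inputs.Image() is the legacy Gradio 3.x component namespace. If the Space is ever moved to a newer Gradio release (an assumption, not something this commit pins), the equivalent construction would look roughly like the sketch below; behaviour is otherwise unchanged, since generate_report already expects a NumPy array via Image.fromarray.

import gradio as gr

iface = gr.Interface(
    fn=generate_report,
    inputs=gr.Image(type="numpy"),  # NumPy array in, matching Image.fromarray(image) above
    outputs="text",
    title="PromptNet",
    description="Upload a medical image for thorax disease reporting.",
)
iface.launch()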
ckpts/few-shot.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fa4c3ef1a822fdca8895f6ad0c73b4f355b036d0d28a8523aaf51f58c7393f38
size 1660341639
data/annotation.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5d9590de8db89b0c74343a7e2aecba61e8029e15801de10ec4e030be80b62adc
size 155745921
decoder_config/decoder_config.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c454e6bddb15af52c82734f1796391bf3a10a6c5533ea095de06f661ebb858bb
size 1744
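The three entries above are Git LFS pointer files rather than the binaries themselves; the payloads are resolved from LFS storage when the repository is cloned or pulled. The recorded oid is simply the SHA-256 of the stored file, so a downloaded copy can be verified with a few lines of Python (the path below is only an example):

import hashlib

def sha256_of(path, chunk_size=1 << 20):
    digest = hashlib.sha256()
    with open(path, 'rb') as f:
        while chunk := f.read(chunk_size):
            digest.update(chunk)
    return digest.hexdigest()

# e.g. sha256_of('ckpts/few-shot.pth') should match
# 'fa4c3ef1a822fdca8895f6ad0c73b4f355b036d0d28a8523aaf51f58c7393f38'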
example_figs/example_fig1.jpg.png
ADDED
example_figs/example_fig2.jpg.jpg
ADDED
example_figs/example_fig3.jpg.png
ADDED
inference.py
ADDED
@@ -0,0 +1,110 @@
import torch
from models.r2gen import R2GenModel
from PIL import Image
from modules.tokenizers import Tokenizer
import main
import argparse
import json
import re
from collections import Counter

def parse_agrs():
    parser = argparse.ArgumentParser()

    # Model loader settings
    parser.add_argument('--load', type=str, default='ckpt/checkpoint.pth', help='the path to the model weights.')
    parser.add_argument('--prompt', type=str, default='ckpt/prompt.pth', help='the path to the prompt weights.')

    # Data input settings
    parser.add_argument('--image_path', type=str, default='example_figs/fig1.jpg', help='the path to the test image.')
    parser.add_argument('--image_dir', type=str, default='data/images/', help='the path to the directory containing the images.')
    parser.add_argument('--ann_path', type=str, default='data/annotation.json', help='the path to the annotation file.')

    # Data loader settings
    parser.add_argument('--dataset_name', type=str, default='mimic_cxr', help='the dataset to be used.')
    parser.add_argument('--max_seq_length', type=int, default=60, help='the maximum sequence length of the reports.')
    parser.add_argument('--threshold', type=int, default=3, help='the cut-off frequency for the words.')
    parser.add_argument('--num_workers', type=int, default=2, help='the number of workers for the dataloader.')
    parser.add_argument('--batch_size', type=int, default=16, help='the number of samples for a batch')

    # Model settings (for visual extractor)
    parser.add_argument('--visual_extractor', type=str, default='resnet101', help='the visual extractor to be used.')
    parser.add_argument('--visual_extractor_pretrained', type=bool, default=True, help='whether to load the pretrained visual extractor')

    # Model settings (for Transformer)
    parser.add_argument('--d_model', type=int, default=512, help='the dimension of Transformer.')
    parser.add_argument('--d_ff', type=int, default=512, help='the dimension of FFN.')
    parser.add_argument('--d_vf', type=int, default=2048, help='the dimension of the patch features.')
    parser.add_argument('--num_heads', type=int, default=8, help='the number of heads in Transformer.')
    parser.add_argument('--num_layers', type=int, default=3, help='the number of layers of Transformer.')
    parser.add_argument('--dropout', type=float, default=0.1, help='the dropout rate of Transformer.')
    parser.add_argument('--logit_layers', type=int, default=1, help='the number of the logit layers.')
    parser.add_argument('--bos_idx', type=int, default=0, help='the index of <bos>.')
    parser.add_argument('--eos_idx', type=int, default=0, help='the index of <eos>.')
    parser.add_argument('--pad_idx', type=int, default=0, help='the index of <pad>.')
    parser.add_argument('--use_bn', type=int, default=0, help='whether to use batch normalization.')
    parser.add_argument('--drop_prob_lm', type=float, default=0.5, help='the dropout rate of the output layer.')
    # for Relational Memory
    parser.add_argument('--rm_num_slots', type=int, default=3, help='the number of memory slots.')
    parser.add_argument('--rm_num_heads', type=int, default=8, help='the number of heads in rm.')
    parser.add_argument('--rm_d_model', type=int, default=512, help='the dimension of rm.')

    # Sample related
    parser.add_argument('--sample_method', type=str, default='beam_search', help='the sample method to sample a report.')
    parser.add_argument('--beam_size', type=int, default=3, help='the beam size when beam searching.')
    parser.add_argument('--temperature', type=float, default=1.0, help='the temperature when sampling.')
    parser.add_argument('--sample_n', type=int, default=1, help='the sample number per image.')
    parser.add_argument('--group_size', type=int, default=1, help='the group size.')
    parser.add_argument('--output_logsoftmax', type=int, default=1, help='whether to output the probabilities.')
    parser.add_argument('--decoding_constraint', type=int, default=0, help='whether to use the decoding constraint.')
    parser.add_argument('--block_trigrams', type=int, default=1, help='whether to block repeated trigrams.')

    # Trainer settings
    parser.add_argument('--n_gpu', type=int, default=1, help='the number of gpus to be used.')
    parser.add_argument('--epochs', type=int, default=100, help='the number of training epochs.')
    parser.add_argument('--save_dir', type=str, default='results/iu_xray', help='the path to save the models.')
    parser.add_argument('--record_dir', type=str, default='records/', help='the path to save the results of experiments')
    parser.add_argument('--save_period', type=int, default=1, help='the saving period.')
    parser.add_argument('--monitor_mode', type=str, default='max', choices=['min', 'max'], help='whether to max or min the metric.')
    parser.add_argument('--monitor_metric', type=str, default='BLEU_4', help='the metric to be monitored.')
    parser.add_argument('--early_stop', type=int, default=50, help='the patience of training.')

    # Optimization
    parser.add_argument('--optim', type=str, default='Adam', help='the type of the optimizer.')
    parser.add_argument('--lr_ve', type=float, default=5e-5, help='the learning rate for the visual extractor.')
    parser.add_argument('--lr_ed', type=float, default=1e-4, help='the learning rate for the remaining parameters.')
    parser.add_argument('--weight_decay', type=float, default=5e-5, help='the weight decay.')
    parser.add_argument('--amsgrad', type=bool, default=True, help='whether to use AMSGrad.')

    # Learning rate scheduler
    parser.add_argument('--lr_scheduler', type=str, default='StepLR', help='the type of the learning rate scheduler.')
    parser.add_argument('--step_size', type=int, default=50, help='the step size of the learning rate scheduler.')
    parser.add_argument('--gamma', type=float, default=0.1, help='the gamma of the learning rate scheduler.')

    # Others
    parser.add_argument('--seed', type=int, default=9233, help='the random seed.')
    parser.add_argument('--resume', type=str, help='whether to resume the training from existing checkpoints.')

    args = parser.parse_args()
    return args


args = parse_agrs()
tokenizer = Tokenizer(args)
image_path = args.image_path
checkpoint_path = args.load

image = [Image.open(image_path).convert('RGB')]
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = R2GenModel(args, tokenizer).to(device)

state_dict = torch.load(checkpoint_path, map_location=device)
model_state_dict = state_dict['state_dict']
# load_state_dict() returns a result object rather than the model, so load the weights
# first; the model was already moved to the target device above.
model.load_state_dict(model_state_dict)

model.eval()
with torch.no_grad():
    output = model(image, mode='sample')
    reports = model.tokenizer.decode_batch(output.cpu().numpy())
    print(reports)
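Note that the script above hands PIL images directly to R2GenModel, whose ResNet-101 visual extractor operates on normalised image tensors. The exact resize and normalisation used for training are not part of this commit, but a typical torchvision preprocessing step would look like the sketch below (the 224x224 size and ImageNet statistics are assumptions, not values confirmed by the repo):

from PIL import Image
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# image_tensor = preprocess(Image.open(image_path).convert('RGB')).unsqueeze(0)  # shape (1, 3, 224, 224)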
models/models.py
ADDED
@@ -0,0 +1,125 @@
import numpy as np
import torch
import torch.nn as nn
import pickle
from typing import Tuple
from transformers import GPT2LMHeadModel
from modules.decoder import DeCap
from medclip import MedCLIPModel, MedCLIPVisionModelViT
import math
import pdb


class MedCapModel(nn.Module):
    def __init__(self, args, tokenizer):
        super(MedCapModel, self).__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.args = args
        self.tokenizer = tokenizer
        self.model = DeCap(args, tokenizer)

        self.align_model = MedCLIPModel(vision_cls=MedCLIPVisionModelViT)
        self.align_model.from_pretrained()
        self.prompt = torch.load(args.prompt)
        if args.dataset == 'iu_xray':
            self.forward = self.forward_iu_xray
        else:
            self.forward = self.forward_mimic_cxr

    def noise_injection(self, x, variance=0.001, modality_offset=None, dont_norm=False):
        if variance == 0.0:
            return x
        std = math.sqrt(variance)
        if not dont_norm:
            x = torch.nn.functional.normalize(x, dim=1)
        # add Gaussian noise on the same device as x
        # (by some conventions multivariate noise should be divided by sqrt of dim)
        x = x + torch.randn(x.shape, device=x.device) * std
        if modality_offset is not None:
            x = x + modality_offset
        return torch.nn.functional.normalize(x, dim=1)

    def align_encode_images_iu_xray(self, images):
        # Split the two views
        image1, image2 = images.unbind(dim=1)
        # Encode each image
        feature1 = self.align_model.encode_image(image1)
        feature2 = self.align_model.encode_image(image2)
        if self.args.prompt_load == 'yes':
            sim_1 = feature1 @ self.prompt.T.float()
            sim_1 = (sim_1 * 100).softmax(dim=-1)
            prefix_embedding_1 = sim_1 @ self.prompt.float()
            prefix_embedding_1 /= prefix_embedding_1.norm(dim=-1, keepdim=True)

            sim_2 = feature2 @ self.prompt.T.float()
            sim_2 = (sim_2 * 100).softmax(dim=-1)
            prefix_embedding_2 = sim_2 @ self.prompt.float()
            prefix_embedding_2 /= prefix_embedding_2.norm(dim=-1, keepdim=True)
            averaged_prompt_features = torch.mean(torch.stack([prefix_embedding_1, prefix_embedding_2]), dim=0)
            return averaged_prompt_features
        else:
            # Average the features of the two views
            averaged_features = torch.mean(torch.stack([feature1, feature2]), dim=0)
            return averaged_features

    def align_encode_images_mimic_cxr(self, images):
        feature = self.align_model.encode_image(images)
        if self.args.prompt_load == 'yes':
            sim = feature @ self.prompt.T.float()
            sim = (sim * 100).softmax(dim=-1)
            prefix_embedding = sim @ self.prompt.float()
            prefix_embedding /= prefix_embedding.norm(dim=-1, keepdim=True)
            return prefix_embedding
        else:
            return feature

    def forward_iu_xray(self, reports_ids, align_ids, align_masks, images, mode='train', update_opts={}):
        self.align_model.to(self.device)
        self.align_model.eval()
        align_ids = align_ids.long()

        align_image_feature = None
        if self.args.train_mode == 'fine-tuning':
            align_image_feature = self.align_encode_images_iu_xray(images)
        if mode == 'train':
            align_text_feature = self.align_model.encode_text(align_ids, align_masks)
            if self.args.noise_inject == 'yes':
                align_text_feature = self.noise_injection(align_text_feature)

            if self.args.train_mode == 'fine-tuning':
                if self.args.F_version == 'v1':
                    # note: self.fc_reduce_dim (a projection back to the text-feature width) is not defined in this file
                    combined_feature = torch.cat([align_text_feature, align_image_feature], dim=-1)
                    align_text_feature = self.fc_reduce_dim(combined_feature)
                if self.args.F_version == 'v2':
                    align_text_feature = align_image_feature

            outputs = self.model(align_text_feature, reports_ids, mode='forward')
            logits = outputs.logits
            logits = logits[:, :-1]
            return logits
        elif mode == 'sample':
            align_image_feature = self.align_encode_images_iu_xray(images)
            outputs = self.model(align_image_feature, reports_ids, mode='sample', update_opts=update_opts)
            return outputs
        else:
            raise ValueError

    def forward_mimic_cxr(self, reports_ids, align_ids, align_masks, images, mode='train', update_opts={}):
        self.align_model.to(self.device)
        self.align_model.eval()
        align_ids = align_ids.long()
        if mode == 'train':
            align_text_feature = self.align_model.encode_text(align_ids, align_masks)
            if self.args.noise_inject == 'yes':
                align_text_feature = self.noise_injection(align_text_feature)
            outputs = self.model(align_text_feature, reports_ids, mode='forward')
            logits = outputs.logits
            logits = logits[:, :-1]
            return logits
        elif mode == 'sample':
            align_image_feature = self.align_encode_images_mimic_cxr(images)
            outputs = self.model(align_image_feature, reports_ids, mode='sample', update_opts=update_opts)
            return outputs
        else:
            raise ValueError
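The prompt path of align_encode_images_* above boils down to one projection: the image feature is compared against the stored prompt bank, the similarities are sharpened with a softmax, and the feature is replaced by a normalised convex combination of the prompt embeddings. A standalone sketch of that step (shapes assumed for illustration: feature (B, D), prompt bank (N, D)):

import torch

def project_onto_prompts(feature: torch.Tensor, prompt: torch.Tensor) -> torch.Tensor:
    sim = feature @ prompt.T.float()          # (B, N) similarities against the prompt bank
    weights = (sim * 100).softmax(dim=-1)     # sharpened weights, as in the methods above
    projected = weights @ prompt.float()      # (B, D) convex combination of prompt embeddings
    return projected / projected.norm(dim=-1, keepdim=True)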
models/r2gen.py
ADDED
@@ -0,0 +1,63 @@
import torch
import torch.nn as nn
import numpy as np

from modules.visual_extractor import VisualExtractor
from modules.encoder_decoder import EncoderDecoder
import torch.nn.functional as F

class R2GenModel(nn.Module):
    def __init__(self, args, tokenizer):
        super(R2GenModel, self).__init__()
        self.args = args
        self.tokenizer = tokenizer
        self.visual_extractor = VisualExtractor(args)
        self.encoder_decoder = EncoderDecoder(args, tokenizer)
        if args.dataset_name == 'iu_xray':
            self.forward = self.forward_iu_xray
        else:
            self.forward = self.forward_mimic_cxr
        self.affine_a = nn.Linear(1024, 2048)
        self.affine_b = nn.Linear(1024, 2048)
        self.affine_c = nn.Linear(1024, 2048)
        self.affine_d = nn.Linear(1024, 2048)
        self.affine_aa = nn.Linear(1024, 2048)
        self.affine_bb = nn.Linear(1024, 2048)

    def __str__(self):
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        params = sum([np.prod(p.size()) for p in model_parameters])
        return super().__str__() + '\nTrainable parameters: {}'.format(params)

    def forward_iu_xray(self, images, targets=None, mode='train'):
        att_feats_0, fc_feats_0 = self.visual_extractor(images[:, 0])
        att_feats_1, fc_feats_1 = self.visual_extractor(images[:, 1])
        # newly added: project the 1024-d visual features to the width expected by the encoder-decoder
        att_feats_0 = F.relu(self.affine_a(att_feats_0))
        fc_feats_0 = F.relu(self.affine_b(fc_feats_0))
        att_feats_1 = F.relu(self.affine_c(att_feats_1))
        fc_feats_1 = F.relu(self.affine_d(fc_feats_1))

        fc_feats = torch.cat((fc_feats_0, fc_feats_1), dim=1)
        att_feats = torch.cat((att_feats_0, att_feats_1), dim=1)
        if mode == 'train':
            output = self.encoder_decoder(fc_feats, att_feats, targets, mode='forward')
        elif mode == 'sample':
            output, _ = self.encoder_decoder(fc_feats, att_feats, mode='sample')
        else:
            raise ValueError
        return output

    def forward_mimic_cxr(self, images, targets=None, mode='train'):
        att_feats1, fc_feats1 = self.visual_extractor(images)
        att_feats = F.relu(self.affine_aa(att_feats1))
        fc_feats = F.relu(self.affine_bb(fc_feats1))

        if mode == 'train':
            output = self.encoder_decoder(fc_feats, att_feats, targets, mode='forward')
        elif mode == 'sample':
            output, _ = self.encoder_decoder(fc_feats, att_feats, mode='sample')
        else:
            raise ValueError
        return output
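For orientation, a standalone shape check of the affine fusion used in forward_iu_xray above. The 1024-to-2048 widths come from the layer definitions in this file (2048 matches the d_vf default in app.py and inference.py); the batch size and 49-patch grid are arbitrary example values:

import torch
import torch.nn as nn
import torch.nn.functional as F

affine_a, affine_b = nn.Linear(1024, 2048), nn.Linear(1024, 2048)
att_feats_0 = torch.randn(2, 49, 1024)   # patch features from one view
fc_feats_0 = torch.randn(2, 1024)        # pooled feature from one view

att = F.relu(affine_a(att_feats_0))      # -> (2, 49, 2048)
fc = F.relu(affine_b(fc_feats_0))        # -> (2, 2048)
print(att.shape, fc.shape)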
modules/att_model.py
ADDED
@@ -0,0 +1,319 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence

import modules.utils as utils
from modules.caption_model import CaptionModel


def sort_pack_padded_sequence(input, lengths):
    sorted_lengths, indices = torch.sort(lengths, descending=True)
    tmp = pack_padded_sequence(input[indices], sorted_lengths, batch_first=True)
    inv_ix = indices.clone()
    inv_ix[indices] = torch.arange(0, len(indices)).type_as(inv_ix)
    return tmp, inv_ix


def pad_unsort_packed_sequence(input, inv_ix):
    tmp, _ = pad_packed_sequence(input, batch_first=True)
    tmp = tmp[inv_ix]
    return tmp


def pack_wrapper(module, att_feats, att_masks):
    if att_masks is not None:
        packed, inv_ix = sort_pack_padded_sequence(att_feats, att_masks.data.long().sum(1))
        return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix)
    else:
        return module(att_feats)


class AttModel(CaptionModel):
    def __init__(self, args, tokenizer):
        super(AttModel, self).__init__()
        self.args = args
        self.tokenizer = tokenizer
        self.vocab_size = len(tokenizer.idx2token)
        self.input_encoding_size = args.d_model
        self.rnn_size = args.d_ff
        self.num_layers = args.num_layers
        self.drop_prob_lm = args.drop_prob_lm
        self.max_seq_length = args.max_seq_length
        self.att_feat_size = args.d_vf
        self.att_hid_size = args.d_model

        self.bos_idx = args.bos_idx
        self.eos_idx = args.eos_idx
        self.pad_idx = args.pad_idx

        self.use_bn = args.use_bn

        self.embed = lambda x: x
        self.fc_embed = lambda x: x
        self.att_embed = nn.Sequential(*(
                ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ()) +
                (nn.Linear(self.att_feat_size, self.input_encoding_size),
                 nn.ReLU(),
                 nn.Dropout(self.drop_prob_lm)) +
                ((nn.BatchNorm1d(self.input_encoding_size),) if self.use_bn == 2 else ())))

    def clip_att(self, att_feats, att_masks):
        # Clip the length of att_masks and att_feats to the maximum length
        if att_masks is not None:
            max_len = att_masks.data.long().sum(1).max()
            att_feats = att_feats[:, :max_len].contiguous()
            att_masks = att_masks[:, :max_len].contiguous()
        return att_feats, att_masks

    def _prepare_feature(self, fc_feats, att_feats, att_masks):
        att_feats, att_masks = self.clip_att(att_feats, att_masks)

        # embed fc and att feats
        fc_feats = self.fc_embed(fc_feats)
        att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)

        # Project the attention feats first to reduce memory and computation consumption.
        p_att_feats = self.ctx2att(att_feats)

        return fc_feats, att_feats, p_att_feats, att_masks

    def get_logprobs_state(self, it, fc_feats, att_feats, p_att_feats, att_masks, state, output_logsoftmax=1):
        # 'it' contains a word index
        xt = self.embed(it)

        output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks)
        if output_logsoftmax:
            logprobs = F.log_softmax(self.logit(output), dim=1)
        else:
            logprobs = self.logit(output)

        return logprobs, state

    def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
        beam_size = opt.get('beam_size', 10)
        group_size = opt.get('group_size', 1)
        sample_n = opt.get('sample_n', 10)
        # when sample_n == beam_size then each beam is a sample.
        assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'
        batch_size = fc_feats.size(0)

        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)

        assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
        seq = fc_feats.new_full((batch_size * sample_n, self.max_seq_length), self.pad_idx, dtype=torch.long)
        seqLogprobs = fc_feats.new_zeros(batch_size * sample_n, self.max_seq_length, self.vocab_size + 1)
        # lets process every image independently for now, for simplicity

        self.done_beams = [[] for _ in range(batch_size)]

        state = self.init_hidden(batch_size)

        # first step, feed bos
        it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long)
        logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)

        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(beam_size,
                                                                                  [p_fc_feats, p_att_feats,
                                                                                   pp_att_feats, p_att_masks])
        self.done_beams = self.beam_search(state, logprobs, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, opt=opt)
        for k in range(batch_size):
            if sample_n == beam_size:
                for _n in range(sample_n):
                    seq_len = self.done_beams[k][_n]['seq'].shape[0]
                    seq[k * sample_n + _n, :seq_len] = self.done_beams[k][_n]['seq']
                    seqLogprobs[k * sample_n + _n, :seq_len] = self.done_beams[k][_n]['logps']
            else:
                seq_len = self.done_beams[k][0]['seq'].shape[0]
                seq[k, :seq_len] = self.done_beams[k][0]['seq']  # the first beam has highest cumulative score
                seqLogprobs[k, :seq_len] = self.done_beams[k][0]['logps']
        # return the samples and their log likelihoods
        return seq, seqLogprobs

    def _sample(self, fc_feats, att_feats, att_masks=None):
        opt = self.args.__dict__
        sample_method = opt.get('sample_method', 'greedy')
        beam_size = opt.get('beam_size', 1)
        temperature = opt.get('temperature', 1.0)
        sample_n = int(opt.get('sample_n', 1))
        group_size = opt.get('group_size', 1)
        output_logsoftmax = opt.get('output_logsoftmax', 1)
        decoding_constraint = opt.get('decoding_constraint', 0)
        block_trigrams = opt.get('block_trigrams', 0)
        if beam_size > 1 and sample_method in ['greedy', 'beam_search']:
            return self._sample_beam(fc_feats, att_feats, att_masks, opt)
        if group_size > 1:
            return self._diverse_sample(fc_feats, att_feats, att_masks, opt)

        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size * sample_n)

        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)

        if sample_n > 1:
            p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(sample_n,
                                                                                      [p_fc_feats, p_att_feats,
                                                                                       pp_att_feats, p_att_masks])

        trigrams = []  # will be a list of batch_size dictionaries

        seq = fc_feats.new_full((batch_size * sample_n, self.max_seq_length), self.pad_idx, dtype=torch.long)
        seqLogprobs = fc_feats.new_zeros(batch_size * sample_n, self.max_seq_length, self.vocab_size + 1)
        for t in range(self.max_seq_length + 1):
            if t == 0:  # input <bos>
                it = fc_feats.new_full([batch_size * sample_n], self.bos_idx, dtype=torch.long)

            logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state,
                                                      output_logsoftmax=output_logsoftmax)

            if decoding_constraint and t > 0:
                tmp = logprobs.new_zeros(logprobs.size())
                tmp.scatter_(1, seq[:, t - 1].data.unsqueeze(1), float('-inf'))
                logprobs = logprobs + tmp

            # Mess with trigrams
            # Copied from https://github.com/lukemelas/image-paragraph-captioning
            if block_trigrams and t >= 3:
                # Store trigram generated at last step
                prev_two_batch = seq[:, t - 3:t - 1]
                for i in range(batch_size):  # = seq.size(0)
                    prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
                    current = seq[i][t - 1]
                    if t == 3:  # initialize
                        trigrams.append({prev_two: [current]})  # {LongTensor: list containing 1 int}
                    elif t > 3:
                        if prev_two in trigrams[i]:  # add to list
                            trigrams[i][prev_two].append(current)
                        else:  # create list
                            trigrams[i][prev_two] = [current]
                # Block used trigrams at next step
                prev_two_batch = seq[:, t - 2:t]
                mask = torch.zeros(logprobs.size(), requires_grad=False).cuda()  # batch_size x vocab_size
                for i in range(batch_size):
                    prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
                    if prev_two in trigrams[i]:
                        for j in trigrams[i][prev_two]:
                            mask[i, j] += 1
                # Apply mask to log probs
                # logprobs = logprobs - (mask * 1e9)
                alpha = 2.0  # = 4
                logprobs = logprobs + (mask * -0.693 * alpha)  # ln(1/2) * alpha (alpha -> infty works best)

            # sample the next word
            if t == self.max_seq_length:  # skip if we achieve maximum length
                break
            it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, temperature)

            # stop when all finished
            if t == 0:
                unfinished = it != self.eos_idx
216 |
+
else:
|
217 |
+
it[~unfinished] = self.pad_idx # This allows eos_idx not being overwritten to 0
|
218 |
+
logprobs = logprobs * unfinished.unsqueeze(1).float()
|
219 |
+
unfinished = unfinished * (it != self.eos_idx)
|
220 |
+
seq[:, t] = it
|
221 |
+
seqLogprobs[:, t] = logprobs
|
222 |
+
# quit loop if all sequences have finished
|
223 |
+
if unfinished.sum() == 0:
|
224 |
+
break
|
225 |
+
|
226 |
+
return seq, seqLogprobs
|
227 |
+
|
228 |
+
def _diverse_sample(self, fc_feats, att_feats, att_masks=None, opt={}):
|
229 |
+
|
230 |
+
sample_method = opt.get('sample_method', 'greedy')
|
231 |
+
beam_size = opt.get('beam_size', 1)
|
232 |
+
temperature = opt.get('temperature', 1.0)
|
233 |
+
group_size = opt.get('group_size', 1)
|
234 |
+
diversity_lambda = opt.get('diversity_lambda', 0.5)
|
235 |
+
decoding_constraint = opt.get('decoding_constraint', 0)
|
236 |
+
block_trigrams = opt.get('block_trigrams', 0)
|
237 |
+
|
238 |
+
batch_size = fc_feats.size(0)
|
239 |
+
state = self.init_hidden(batch_size)
|
240 |
+
|
241 |
+
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
|
242 |
+
|
243 |
+
trigrams_table = [[] for _ in range(group_size)] # will be a list of batch_size dictionaries
|
244 |
+
|
245 |
+
seq_table = [fc_feats.new_full((batch_size, self.max_seq_length), self.pad_idx, dtype=torch.long) for _ in
|
246 |
+
range(group_size)]
|
247 |
+
seqLogprobs_table = [fc_feats.new_zeros(batch_size, self.max_seq_length) for _ in range(group_size)]
|
248 |
+
state_table = [self.init_hidden(batch_size) for _ in range(group_size)]
|
249 |
+
|
250 |
+
for tt in range(self.max_seq_length + group_size):
|
251 |
+
for divm in range(group_size):
|
252 |
+
t = tt - divm
|
253 |
+
seq = seq_table[divm]
|
254 |
+
seqLogprobs = seqLogprobs_table[divm]
|
255 |
+
trigrams = trigrams_table[divm]
|
256 |
+
if t >= 0 and t <= self.max_seq_length - 1:
|
257 |
+
if t == 0: # input <bos>
|
258 |
+
it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long)
|
259 |
+
else:
|
260 |
+
it = seq[:, t - 1] # changed
|
261 |
+
|
262 |
+
logprobs, state_table[divm] = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats,
|
263 |
+
p_att_masks, state_table[divm]) # changed
|
264 |
+
logprobs = F.log_softmax(logprobs / temperature, dim=-1)
|
265 |
+
|
266 |
+
# Add diversity
|
267 |
+
if divm > 0:
|
268 |
+
unaug_logprobs = logprobs.clone()
|
269 |
+
for prev_choice in range(divm):
|
270 |
+
prev_decisions = seq_table[prev_choice][:, t]
|
271 |
+
logprobs[:, prev_decisions] = logprobs[:, prev_decisions] - diversity_lambda
|
272 |
+
|
273 |
+
if decoding_constraint and t > 0:
|
274 |
+
tmp = logprobs.new_zeros(logprobs.size())
|
275 |
+
tmp.scatter_(1, seq[:, t - 1].data.unsqueeze(1), float('-inf'))
|
276 |
+
logprobs = logprobs + tmp
|
277 |
+
|
278 |
+
# Mess with trigrams
|
279 |
+
if block_trigrams and t >= 3:
|
280 |
+
# Store trigram generated at last step
|
281 |
+
prev_two_batch = seq[:, t - 3:t - 1]
|
282 |
+
for i in range(batch_size): # = seq.size(0)
|
283 |
+
prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
|
284 |
+
current = seq[i][t - 1]
|
285 |
+
if t == 3: # initialize
|
286 |
+
trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int}
|
287 |
+
elif t > 3:
|
288 |
+
if prev_two in trigrams[i]: # add to list
|
289 |
+
trigrams[i][prev_two].append(current)
|
290 |
+
else: # create list
|
291 |
+
trigrams[i][prev_two] = [current]
|
292 |
+
# Block used trigrams at next step
|
293 |
+
prev_two_batch = seq[:, t - 2:t]
|
294 |
+
mask = torch.zeros(logprobs.size(), requires_grad=False).cuda() # batch_size x vocab_size
|
295 |
+
for i in range(batch_size):
|
296 |
+
prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
|
297 |
+
if prev_two in trigrams[i]:
|
298 |
+
for j in trigrams[i][prev_two]:
|
299 |
+
mask[i, j] += 1
|
300 |
+
# Apply mask to log probs
|
301 |
+
# logprobs = logprobs - (mask * 1e9)
|
302 |
+
alpha = 2.0 # = 4
|
303 |
+
logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best)
|
304 |
+
|
305 |
+
it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, 1)
|
306 |
+
|
307 |
+
# stop when all finished
|
308 |
+
if t == 0:
|
309 |
+
unfinished = it != self.eos_idx
|
310 |
+
else:
|
311 |
+
unfinished = seq[:, t - 1] != self.pad_idx & seq[:, t - 1] != self.eos_idx
|
312 |
+
it[~unfinished] = self.pad_idx
|
313 |
+
unfinished = unfinished & (it != self.eos_idx) # changed
|
314 |
+
seq[:, t] = it
|
315 |
+
seqLogprobs[:, t] = sampleLogprobs.view(-1)
|
316 |
+
|
317 |
+
return torch.stack(seq_table, 1).reshape(batch_size * group_size, -1), torch.stack(seqLogprobs_table,
|
318 |
+
1).reshape(
|
319 |
+
batch_size * group_size, -1)
|
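The block_trigrams branch above down-weights any token that would repeat a trigram already produced for that example. A minimal standalone sketch of the same idea, assuming tensors of generated ids and per-example trigram dictionaries (the function and variable names here are illustrative, not part of the uploaded files):

import math

def block_repeated_trigrams(seq, logprobs, trigrams, t, alpha=2.0):
    """Penalize tokens that would repeat an already generated trigram (call with t >= 3).

    seq:      LongTensor (batch, >= t) of tokens generated so far
    logprobs: FloatTensor (batch, vocab) of next-step log-probabilities, modified in place
    trigrams: list of per-example dicts mapping a (w1, w2) prefix to words already used after it
    """
    for i in range(seq.size(0)):
        # remember the trigram that just ended at position t - 1
        prev_two = (seq[i, t - 3].item(), seq[i, t - 2].item())
        trigrams[i].setdefault(prev_two, []).append(seq[i, t - 1].item())
        # penalize every continuation of the current two-token prefix that was already used
        cur_prefix = (seq[i, t - 2].item(), seq[i, t - 1].item())
        for w in trigrams[i].get(cur_prefix, []):
            logprobs[i, w] += math.log(0.5) * alpha  # subtract alpha * ln(2) per prior use
    return logprobs, trigrams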
modules/att_models.py
ADDED
@@ -0,0 +1,120 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import pdb

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence

import modules.utils as utils
from modules.caption_model import CaptionModel


class AttModel(CaptionModel):
    def __init__(self, args, tokenizer):
        super(AttModel, self).__init__()
        self.args = args
        self.tokenizer = tokenizer
        self.vocab_size = len(tokenizer.idx2token)
        self.max_seq_length = 60

    def _sample(self, clip_features, gpt_tokens, update_opts={}):

        opt = self.args.__dict__
        opt.update(**update_opts)
        sample_method = opt.get('sample_method', 'greedy')

        if sample_method == 'greedy':
            return self._greedy_sample(clip_features, gpt_tokens)
        elif sample_method == 'beam_search':
            return self._beam_search_sample(clip_features, gpt_tokens)
        else:
            raise ValueError("Unknown sample_method: " + sample_method)

    def _greedy_sample(self, clip_features, gpt_tokens, temperature=1.0):
        # input_ids = torch.full((clip_features.size(0), 1), self.tokenizer.bos_token_id).type_as(clip_features).long()
        clip_features = self.clip_project(clip_features).reshape(clip_features.size(0), 1, -1)
        tokens = [None for _ in range(clip_features.size(0))]
        finished = [False for _ in range(clip_features.size(0))]
        max_length = 200
        for _ in range(max_length):
            outputs = self.decoder(inputs_embeds=clip_features)
            logits = outputs.logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
            next_tokens = torch.argmax(logits, -1).unsqueeze(1)
            next_token_embeds = self.decoder.transformer.wte(next_tokens)
            for j in range(clip_features.size(0)):
                if finished[j]:
                    continue
                if tokens[j] is None:
                    tokens[j] = next_tokens[j]
                else:
                    tokens[j] = torch.cat((tokens[j], next_tokens[j]), dim=0)
                if next_tokens[j].item() == self.tokenizer.eos_token_id:
                    finished[j] = True
            clip_features = torch.cat((clip_features, next_token_embeds), dim=1)
        outputs = []
        for token in tokens:
            try:
                output_list = token.squeeze().cpu().numpy().tolist()
                # Pad or truncate output_list to max_length
                output_list = (output_list + [self.tokenizer.pad_token_id] * max_length)[:max_length]
            except Exception as e:
                print(f"Error during decoding: {type(e).__name__}: {e}")
                output_list = [self.tokenizer.pad_token_id] * max_length
            outputs.append(output_list)

        # Convert list of lists to tensor
        outputs = torch.tensor(outputs, device=clip_features.device)
        return outputs

    def _beam_search_sample(self, clip_features, gpt_tokens, beam_size=5):
        batch_size = clip_features.size(0)
        # Prepare the first input for every beam
        input_ids = torch.full((batch_size * beam_size, 1), self.tokenizer.bos_token_id).type_as(clip_features).long()
        beam_scores = torch.zeros((batch_size, beam_size)).type_as(clip_features)
        done = [False] * batch_size

        for _ in range(self.max_seq_length):
            outputs = self._forward(clip_features.repeat_interleave(beam_size, 0), input_ids)
            next_token_logits = outputs.logits[:, -1, :]
            next_token_probs = F.softmax(next_token_logits, dim=-1)

            # Apply a mask for already finished beams
            next_token_probs[done] = 0
            next_token_probs[:, self.tokenizer.eos_token_id] = -float('Inf')

            # Multiply old scores with new probabilities
            scores = beam_scores.unsqueeze(2) * next_token_probs
            scores = scores.view(batch_size, -1)

            # Get the top beam_size scores and their respective indices
            top_scores, top_indices = scores.topk(beam_size, dim=1)

            # Update beam scores
            beam_scores = top_scores.log()

            # Reshape input_ids
            input_ids = input_ids.view(batch_size, beam_size, -1)

            # Compute next inputs
            next_token_ids = top_indices % self.vocab_size
            beam_indices = top_indices // self.vocab_size
            next_input_ids = torch.cat(
                [input_ids.gather(1, beam_indices.unsqueeze(2).expand(-1, -1, input_ids.size(2))),
                 next_token_ids.unsqueeze(2)], dim=2)

            # Flatten input_ids
            input_ids = next_input_ids.view(batch_size * beam_size, -1)

            # Check which beams are done
            done = (next_token_ids == self.tokenizer.eos_token_id).all(dim=1).tolist()

            if all(done):
                break

        return input_ids.view(batch_size, beam_size, -1)
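_greedy_sample above decodes autoregressively by feeding the projected CLIP prefix as inputs_embeds and appending the embedding of each argmax token. A self-contained sketch of that loop with a small, randomly initialized GPT-2; the configuration sizes and generation length are illustrative only, not values used by this repository:

import torch
from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(vocab_size=1000, n_layer=2, n_head=4, n_embd=64)  # toy sizes
model = GPT2LMHeadModel(config).eval()

prefix = torch.randn(1, 1, config.n_embd)   # stands in for the projected image feature
generated, embeds = [], prefix
with torch.no_grad():
    for _ in range(10):
        logits = model(inputs_embeds=embeds).logits[:, -1, :]
        next_token = logits.argmax(-1, keepdim=True)      # greedy choice
        generated.append(next_token.item())
        next_embed = model.transformer.wte(next_token)    # embed the new token and append it
        embeds = torch.cat([embeds, next_embed], dim=1)
print(generated)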
modules/caption_model.py
ADDED
@@ -0,0 +1,401 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F

import modules.utils as utils


class CaptionModel(nn.Module):
    def __init__(self):
        super(CaptionModel, self).__init__()

    # implements beam search
    # calls beam_step and returns the final set of beams
    # augments log-probabilities with diversity terms when number of groups > 1

    def forward(self, *args, **kwargs):
        mode = kwargs.get('mode', 'forward')
        if 'mode' in kwargs:
            del kwargs['mode']
        return getattr(self, '_' + mode)(*args, **kwargs)

    def beam_search(self, init_state, init_logprobs, *args, **kwargs):

        # function computes the similarity score to be augmented
        def add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda, bdash):
            local_time = t - divm
            unaug_logprobs = logprobs.clone()
            batch_size = beam_seq_table[0].shape[0]

            if divm > 0:
                change = logprobs.new_zeros(batch_size, logprobs.shape[-1])
                for prev_choice in range(divm):
                    prev_decisions = beam_seq_table[prev_choice][:, :, local_time]  # Nxb
                    for prev_labels in range(bdash):
                        change.scatter_add_(1, prev_decisions[:, prev_labels].unsqueeze(-1),
                                            change.new_ones(batch_size, 1))

                if local_time == 0:
                    logprobs = logprobs - change * diversity_lambda
                else:
                    logprobs = logprobs - self.repeat_tensor(bdash, change) * diversity_lambda

            return logprobs, unaug_logprobs

        # does one step of classical beam search

        def beam_step(logprobs, unaug_logprobs, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
            # INPUTS:
            # logprobs: probabilities augmented after diversity, N*bxV
            # beam_size: obvious
            # t        : time instant
            # beam_seq : tensor containing the beams
            # beam_seq_logprobs: tensor containing the beam logprobs
            # beam_logprobs_sum: tensor containing joint logprobs
            # OUTPUTS:
            # beam_seq : tensor containing the word indices of the decoded captions, Nxbxl
            # beam_seq_logprobs : log-probability of each decision made, NxbxlxV
            # beam_logprobs_sum : joint log-probability of each beam, Nxb

            batch_size = beam_logprobs_sum.shape[0]
            vocab_size = logprobs.shape[-1]
            logprobs = logprobs.reshape(batch_size, -1, vocab_size)  # NxbxV
            if t == 0:
                assert logprobs.shape[1] == 1
                beam_logprobs_sum = beam_logprobs_sum[:, :1]
            candidate_logprobs = beam_logprobs_sum.unsqueeze(-1) + logprobs  # beam_logprobs_sum Nxb, logprobs is NxbxV
            ys, ix = torch.sort(candidate_logprobs.reshape(candidate_logprobs.shape[0], -1), -1, True)
            ys, ix = ys[:, :beam_size], ix[:, :beam_size]
            beam_ix = ix // vocab_size  # Nxb, which beam
            selected_ix = ix % vocab_size  # Nxb, which word
            state_ix = (beam_ix + torch.arange(batch_size).type_as(beam_ix).unsqueeze(-1) * logprobs.shape[1]).reshape(
                -1)  # N*b, which in Nxb beams

            if t > 0:
                # gather according to beam_ix
                assert (beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq)) ==
                        beam_seq.reshape(-1, beam_seq.shape[-1])[state_ix].view_as(beam_seq)).all()
                beam_seq = beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq))

                beam_seq_logprobs = beam_seq_logprobs.gather(1, beam_ix.unsqueeze(-1).unsqueeze(-1).expand_as(
                    beam_seq_logprobs))

            beam_seq = torch.cat([beam_seq, selected_ix.unsqueeze(-1)], -1)  # beam_seq Nxbxl
            beam_logprobs_sum = beam_logprobs_sum.gather(1, beam_ix) + \
                                logprobs.reshape(batch_size, -1).gather(1, ix)
            assert (beam_logprobs_sum == ys).all()
            _tmp_beam_logprobs = unaug_logprobs[state_ix].reshape(batch_size, -1, vocab_size)
            beam_logprobs = unaug_logprobs.reshape(batch_size, -1, vocab_size).gather(
                1, beam_ix.unsqueeze(-1).expand(-1, -1, vocab_size))  # NxbxV
            assert (_tmp_beam_logprobs == beam_logprobs).all()
            beam_seq_logprobs = torch.cat([
                beam_seq_logprobs,
                beam_logprobs.reshape(batch_size, -1, 1, vocab_size)], 2)

            new_state = [None for _ in state]
            for _ix in range(len(new_state)):
                # copy over state in previous beam q to new beam at vix
                new_state[_ix] = state[_ix][:, state_ix]
            state = new_state
            return beam_seq, beam_seq_logprobs, beam_logprobs_sum, state

        # Start diverse_beam_search
        opt = kwargs['opt']
        temperature = opt.get('temperature', 1)  # This should not affect beam search, but will affect dbs
        beam_size = opt.get('beam_size', 10)
        group_size = opt.get('group_size', 1)
        diversity_lambda = opt.get('diversity_lambda', 0.5)
        decoding_constraint = opt.get('decoding_constraint', 0)
        suppress_UNK = opt.get('suppress_UNK', 0)
        length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
        bdash = beam_size // group_size  # beam per group

        batch_size = init_logprobs.shape[0]
        device = init_logprobs.device
        # INITIALIZATIONS
        beam_seq_table = [torch.LongTensor(batch_size, bdash, 0).to(device) for _ in range(group_size)]
        beam_seq_logprobs_table = [torch.FloatTensor(batch_size, bdash, 0, self.vocab_size + 1).to(device)
                                   for _ in range(group_size)]
        beam_logprobs_sum_table = [torch.zeros(batch_size, bdash).to(device) for _ in range(group_size)]

        # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1)
        done_beams_table = [[[] for __ in range(group_size)] for _ in range(batch_size)]
        state_table = [[_.clone() for _ in init_state] for _ in range(group_size)]
        logprobs_table = [init_logprobs.clone() for _ in range(group_size)]
        # END INIT

        # Chunk elements in the args
        args = list(args)
        args = utils.split_tensors(group_size, args)  # For each arg, turn (Bbg)x... to (Bb)x(g)x...
        if self.__class__.__name__ == 'AttEnsemble':
            args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))]
                    for k in range(group_size)]  # group_name, arg_name, model_name
        else:
            args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]

        for t in range(self.max_seq_length + group_size - 1):
            for divm in range(group_size):
                if t >= divm and t <= self.max_seq_length + divm - 1:
                    # add diversity
                    logprobs = logprobs_table[divm]
                    # suppress previous word
                    if decoding_constraint and t - divm > 0:
                        logprobs.scatter_(1, beam_seq_table[divm][:, :, t - divm - 1].reshape(-1, 1).to(device),
                                          float('-inf'))
                    # suppress UNK tokens in the decoding
                    if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobs.size(1) - 1)] == 'UNK':
                        logprobs[:, logprobs.size(1) - 1] = logprobs[:, logprobs.size(1) - 1] - 1000
                    # diversity is added here
                    # the function directly modifies the logprobs values and hence, we need to return
                    # the unaugmented ones for sorting the candidates in the end. # for historical
                    # reasons :-)
                    logprobs, unaug_logprobs = add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda, bdash)

                    # infer new beams
                    beam_seq_table[divm], \
                    beam_seq_logprobs_table[divm], \
                    beam_logprobs_sum_table[divm], \
                    state_table[divm] = beam_step(logprobs,
                                                  unaug_logprobs,
                                                  bdash,
                                                  t - divm,
                                                  beam_seq_table[divm],
                                                  beam_seq_logprobs_table[divm],
                                                  beam_logprobs_sum_table[divm],
                                                  state_table[divm])

                    # if time's up... or if end token is reached then copy beams
                    for b in range(batch_size):
                        is_end = beam_seq_table[divm][b, :, t - divm] == self.eos_idx
                        assert beam_seq_table[divm].shape[-1] == t - divm + 1
                        if t == self.max_seq_length + divm - 1:
                            is_end.fill_(1)
                        for vix in range(bdash):
                            if is_end[vix]:
                                final_beam = {
                                    'seq': beam_seq_table[divm][b, vix].clone(),
                                    'logps': beam_seq_logprobs_table[divm][b, vix].clone(),
                                    'unaug_p': beam_seq_logprobs_table[divm][b, vix].sum().item(),
                                    'p': beam_logprobs_sum_table[divm][b, vix].item()
                                }
                                final_beam['p'] = length_penalty(t - divm + 1, final_beam['p'])
                                done_beams_table[b][divm].append(final_beam)
                        beam_logprobs_sum_table[divm][b, is_end] -= 1000

                    # move the current group one step forward in time

                    it = beam_seq_table[divm][:, :, t - divm].reshape(-1)
                    logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it.cuda(), *(
                            args[divm] + [state_table[divm]]))
                    logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)

        # all beams are sorted by their log-probabilities
        done_beams_table = [[sorted(done_beams_table[b][i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)]
                            for b in range(batch_size)]
        done_beams = [sum(_, []) for _ in done_beams_table]
        return done_beams

    def old_beam_search(self, init_state, init_logprobs, *args, **kwargs):

        # function computes the similarity score to be augmented
        def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash):
            local_time = t - divm
            unaug_logprobsf = logprobsf.clone()
            for prev_choice in range(divm):
                prev_decisions = beam_seq_table[prev_choice][local_time]
                for sub_beam in range(bdash):
                    for prev_labels in range(bdash):
                        logprobsf[sub_beam][prev_decisions[prev_labels]] = \
                            logprobsf[sub_beam][prev_decisions[prev_labels]] - diversity_lambda
            return unaug_logprobsf

        # does one step of classical beam search

        def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
            # INPUTS:
            # logprobsf: probabilities augmented after diversity
            # beam_size: obvious
            # t        : time instant
            # beam_seq : tensor containing the beams
            # beam_seq_logprobs: tensor containing the beam logprobs
            # beam_logprobs_sum: tensor containing joint logprobs
            # OUTPUTS:
            # beam_seq : tensor containing the word indices of the decoded captions
            # beam_seq_logprobs : log-probability of each decision made, same size as beam_seq
            # beam_logprobs_sum : joint log-probability of each beam

            ys, ix = torch.sort(logprobsf, 1, True)
            candidates = []
            cols = min(beam_size, ys.size(1))
            rows = beam_size
            if t == 0:
                rows = 1
            for c in range(cols):  # for each column (word, essentially)
                for q in range(rows):  # for each beam expansion
                    # compute logprob of expanding beam q with word in (sorted) position c
                    local_logprob = ys[q, c].item()
                    candidate_logprob = beam_logprobs_sum[q] + local_logprob
                    # local_unaug_logprob = unaug_logprobsf[q,ix[q,c]]
                    candidates.append({'c': ix[q, c], 'q': q, 'p': candidate_logprob, 'r': unaug_logprobsf[q]})
            candidates = sorted(candidates, key=lambda x: -x['p'])

            new_state = [_.clone() for _ in state]
            # beam_seq_prev, beam_seq_logprobs_prev
            if t >= 1:
                # we'll need these as reference when we fork beams around
                beam_seq_prev = beam_seq[:t].clone()
                beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone()
            for vix in range(beam_size):
                v = candidates[vix]
                # fork beam index q into index vix
                if t >= 1:
                    beam_seq[:t, vix] = beam_seq_prev[:, v['q']]
                    beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']]
                # rearrange recurrent states
                for state_ix in range(len(new_state)):
                    # copy over state in previous beam q to new beam at vix
                    new_state[state_ix][:, vix] = state[state_ix][:, v['q']]  # dimension one is time step
                # append new end terminal at the end of this beam
                beam_seq[t, vix] = v['c']  # c'th word is the continuation
                beam_seq_logprobs[t, vix] = v['r']  # the raw logprob here
                beam_logprobs_sum[vix] = v['p']  # the new (sum) logprob along this beam
            state = new_state
            return beam_seq, beam_seq_logprobs, beam_logprobs_sum, state, candidates

        # Start diverse_beam_search
        opt = kwargs['opt']
        temperature = opt.get('temperature', 1)  # This should not affect beam search, but will affect dbs
        beam_size = opt.get('beam_size', 10)
        group_size = opt.get('group_size', 1)
        diversity_lambda = opt.get('diversity_lambda', 0.5)
        decoding_constraint = opt.get('decoding_constraint', 0)
        suppress_UNK = opt.get('suppress_UNK', 0)
        length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
        bdash = beam_size // group_size  # beam per group

        # INITIALIZATIONS
        beam_seq_table = [torch.LongTensor(self.max_seq_length, bdash).zero_() for _ in range(group_size)]
        beam_seq_logprobs_table = [torch.FloatTensor(self.max_seq_length, bdash, self.vocab_size + 1).zero_()
                                   for _ in range(group_size)]
        beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)]

        # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1)
        done_beams_table = [[] for _ in range(group_size)]
        # state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)]
        state_table = list(zip(*[_.chunk(group_size, 1) for _ in init_state]))
        logprobs_table = list(init_logprobs.chunk(group_size, 0))
        # END INIT

        # Chunk elements in the args
        args = list(args)
        if self.__class__.__name__ == 'AttEnsemble':
            args = [[_.chunk(group_size) if _ is not None else [None] * group_size for _ in args_]
                    for args_ in args]  # arg_name, model_name, group_name
            args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))]
                    for k in range(group_size)]  # group_name, arg_name, model_name
        else:
            args = [_.chunk(group_size) if _ is not None else [None] * group_size for _ in args]
            args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]

        for t in range(self.max_seq_length + group_size - 1):
            for divm in range(group_size):
                if t >= divm and t <= self.max_seq_length + divm - 1:
                    # add diversity
                    logprobsf = logprobs_table[divm].float()
                    # suppress previous word
                    if decoding_constraint and t - divm > 0:
                        logprobsf.scatter_(1, beam_seq_table[divm][t - divm - 1].unsqueeze(1).cuda(), float('-inf'))
                    # suppress UNK tokens in the decoding
                    if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobsf.size(1) - 1)] == 'UNK':
                        logprobsf[:, logprobsf.size(1) - 1] = logprobsf[:, logprobsf.size(1) - 1] - 1000
                    # diversity is added here
                    # the function directly modifies the logprobsf values and hence, we need to return
                    # the unaugmented ones for sorting the candidates in the end. # for historical
                    # reasons :-)
                    unaug_logprobsf = add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash)

                    # infer new beams
                    beam_seq_table[divm], \
                    beam_seq_logprobs_table[divm], \
                    beam_logprobs_sum_table[divm], \
                    state_table[divm], \
                    candidates_divm = beam_step(logprobsf,
                                                unaug_logprobsf,
                                                bdash,
                                                t - divm,
                                                beam_seq_table[divm],
                                                beam_seq_logprobs_table[divm],
                                                beam_logprobs_sum_table[divm],
                                                state_table[divm])

                    # if time's up... or if end token is reached then copy beams
                    for vix in range(bdash):
                        if beam_seq_table[divm][t - divm, vix] == self.eos_idx or t == self.max_seq_length + divm - 1:
                            final_beam = {
                                'seq': beam_seq_table[divm][:, vix].clone(),
                                'logps': beam_seq_logprobs_table[divm][:, vix].clone(),
                                'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum().item(),
                                'p': beam_logprobs_sum_table[divm][vix].item()
                            }
                            final_beam['p'] = length_penalty(t - divm + 1, final_beam['p'])
                            done_beams_table[divm].append(final_beam)
                            # don't continue beams from finished sequences
                            beam_logprobs_sum_table[divm][vix] = -1000

                    # move the current group one step forward in time

                    it = beam_seq_table[divm][t - divm]
                    logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it.cuda(), *(
                            args[divm] + [state_table[divm]]))
                    logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)

        # all beams are sorted by their log-probabilities
        done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)]
        done_beams = sum(done_beams_table, [])
        return done_beams

    def sample_next_word(self, logprobs, sample_method, temperature):
        if sample_method == 'greedy':
            sampleLogprobs, it = torch.max(logprobs.data, 1)
            it = it.view(-1).long()
        elif sample_method == 'gumbel':  # gumbel softmax
            def sample_gumbel(shape, eps=1e-20):
                U = torch.rand(shape).cuda()
                return -torch.log(-torch.log(U + eps) + eps)

            def gumbel_softmax_sample(logits, temperature):
                y = logits + sample_gumbel(logits.size())
                return F.log_softmax(y / temperature, dim=-1)

            _logprobs = gumbel_softmax_sample(logprobs, temperature)
            _, it = torch.max(_logprobs.data, 1)
            sampleLogprobs = logprobs.gather(1, it.unsqueeze(1))  # gather the logprobs at sampled positions
        else:
            logprobs = logprobs / temperature
            if sample_method.startswith('top'):  # topk sampling
                top_num = float(sample_method[3:])
                if 0 < top_num < 1:
                    # nucleus sampling from "The Curious Case of Neural Text Degeneration"
                    probs = F.softmax(logprobs, dim=1)
                    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)
                    _cumsum = sorted_probs.cumsum(1)
                    mask = _cumsum < top_num
                    mask = torch.cat([torch.ones_like(mask[:, :1]), mask[:, :-1]], 1)
                    sorted_probs = sorted_probs * mask.float()
                    sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
                    logprobs.scatter_(1, sorted_indices, sorted_probs.log())
                else:
                    the_k = int(top_num)
                    tmp = torch.empty_like(logprobs).fill_(float('-inf'))
                    topk, indices = torch.topk(logprobs, the_k, dim=1)
                    tmp = tmp.scatter(1, indices, topk)
                    logprobs = tmp
            it = torch.distributions.Categorical(logits=logprobs.detach()).sample()
            sampleLogprobs = logprobs.gather(1, it.unsqueeze(1))  # gather the logprobs at sampled positions
        return it, sampleLogprobs
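The topX branch of sample_next_word covers both top-k and, for fractional values, nucleus (top-p) sampling. A compact standalone sketch of the top-p case, with illustrative names and toy shapes:

import torch
import torch.nn.functional as F

def nucleus_sample(logits, p=0.9):
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
    keep = sorted_probs.cumsum(-1) < p
    keep = torch.cat([torch.ones_like(keep[..., :1]), keep[..., :-1]], dim=-1)  # always keep the top token
    sorted_probs = sorted_probs * keep.float()
    sorted_probs = sorted_probs / sorted_probs.sum(-1, keepdim=True)
    # draw from the truncated distribution, then map back to vocabulary indices
    choice = torch.multinomial(sorted_probs, 1)
    return sorted_idx.gather(-1, choice).squeeze(-1)

tokens = nucleus_sample(torch.randn(4, 100))  # e.g. a (batch=4, vocab=100) batch of logits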
modules/config.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c454e6bddb15af52c82734f1796391bf3a10a6c5533ea095de06f661ebb858bb
size 1744
modules/dataloader.py
ADDED
@@ -0,0 +1,59 @@
import pdb

import torch
from torch.utils.data import DataLoader
from .dataset import IuxrayMultiImageDataset, MimiccxrSingleImageDataset
from medclip import MedCLIPProcessor
import numpy as np


class R2DataLoader(DataLoader):
    def __init__(self, args, tokenizer, split, shuffle):
        self.args = args
        self.dataset_name = args.dataset
        self.batch_size = args.bs
        self.shuffle = shuffle
        self.num_workers = args.num_workers
        self.tokenizer = tokenizer
        self.split = split
        self.processor = MedCLIPProcessor()

        if self.dataset_name == 'iu_xray':
            self.dataset = IuxrayMultiImageDataset(self.args, self.tokenizer, self.split, self.processor)
        else:
            self.dataset = MimiccxrSingleImageDataset(self.args, self.tokenizer, self.split, self.processor)

        self.init_kwargs = {
            'dataset': self.dataset,
            'batch_size': self.batch_size,
            'shuffle': self.shuffle,
            'collate_fn': self.collate_fn,
            'num_workers': self.num_workers
        }
        super().__init__(**self.init_kwargs)

    @staticmethod
    def collate_fn(data):
        image_id_batch, image_batch, report_ids_batch, report_masks_batch, processor_ids_batch, processor_mask_batch, seq_lengths_batch, processor_lengths_batch = zip(*data)
        image_batch = torch.stack(image_batch, 0)

        max_seq_length = max(seq_lengths_batch)
        target_batch = np.zeros((len(report_ids_batch), max_seq_length), dtype=int)
        target_masks_batch = np.zeros((len(report_ids_batch), max_seq_length), dtype=int)

        max_processor_length = max(processor_lengths_batch)
        target_processor_batch = np.zeros((len(processor_ids_batch), max_processor_length), dtype=int)
        target_processor_mask_batch = np.zeros((len(processor_mask_batch), max_processor_length), dtype=int)

        for i, report_ids in enumerate(report_ids_batch):
            target_batch[i, :len(report_ids)] = report_ids

        for i, report_masks in enumerate(report_masks_batch):
            target_masks_batch[i, :len(report_masks)] = report_masks

        for i, report_ids in enumerate(processor_ids_batch):
            target_processor_batch[i, :len(report_ids)] = report_ids

        for i, report_masks in enumerate(processor_mask_batch):
            target_processor_mask_batch[i, :len(report_masks)] = report_masks

        return image_id_batch, image_batch, torch.LongTensor(target_batch), torch.FloatTensor(target_masks_batch), torch.FloatTensor(target_processor_batch), torch.FloatTensor(target_processor_mask_batch)
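collate_fn pads every report (and its MedCLIP-tokenized counterpart) to the longest sequence in the batch before stacking. The padding step in isolation, as a small sketch with toy data and illustrative names:

import numpy as np
import torch

report_ids_batch = ([2, 15, 7, 3], [2, 9, 3])        # two tokenized reports of different length
max_len = max(len(ids) for ids in report_ids_batch)

targets = np.zeros((len(report_ids_batch), max_len), dtype=int)
masks = np.zeros((len(report_ids_batch), max_len), dtype=int)
for i, ids in enumerate(report_ids_batch):
    targets[i, :len(ids)] = ids
    masks[i, :len(ids)] = 1

targets = torch.LongTensor(targets)   # shape (2, 4), zero-padded on the right
masks = torch.FloatTensor(masks)      # 1 wherever a real token is present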
modules/dataloaders.py
ADDED
@@ -0,0 +1,62 @@
import torch
import numpy as np
from torchvision import transforms
from torch.utils.data import DataLoader
from .datasets import IuxrayMultiImageDataset, MimiccxrSingleImageDataset


class R2DataLoader(DataLoader):
    def __init__(self, args, tokenizer, split, shuffle):
        self.args = args
        self.dataset_name = args.dataset_name
        self.batch_size = args.batch_size
        self.shuffle = shuffle
        self.num_workers = args.num_workers
        self.tokenizer = tokenizer
        self.split = split

        if split == 'train':
            self.transform = transforms.Compose([
                transforms.Resize(256),
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225))])
        else:
            self.transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225))])

        if self.dataset_name == 'iu_xray':
            self.dataset = IuxrayMultiImageDataset(self.args, self.tokenizer, self.split, transform=self.transform)
        else:
            self.dataset = MimiccxrSingleImageDataset(self.args, self.tokenizer, self.split, transform=self.transform)

        self.init_kwargs = {
            'dataset': self.dataset,
            'batch_size': self.batch_size,
            'shuffle': self.shuffle,
            'collate_fn': self.collate_fn,
            'num_workers': self.num_workers
        }
        super().__init__(**self.init_kwargs)

    @staticmethod
    def collate_fn(data):
        images_id, images, reports_ids, reports_masks, seq_lengths = zip(*data)
        images = torch.stack(images, 0)
        max_seq_length = max(seq_lengths)

        targets = np.zeros((len(reports_ids), max_seq_length), dtype=int)
        targets_masks = np.zeros((len(reports_ids), max_seq_length), dtype=int)

        for i, report_ids in enumerate(reports_ids):
            targets[i, :len(report_ids)] = report_ids

        for i, report_masks in enumerate(reports_masks):
            targets_masks[i, :len(report_masks)] = report_masks

        return images_id, images, torch.LongTensor(targets), torch.FloatTensor(targets_masks)
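Here the train split uses random crops and flips for augmentation while validation/test uses a deterministic resize; both end with the ImageNet normalization. A quick check of the evaluation pipeline on a dummy image (illustrative only, not part of the uploaded files):

from PIL import Image
from torchvision import transforms

eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

img = Image.new('RGB', (512, 620))   # stand-in for a chest X-ray
x = eval_transform(img)
print(x.shape)                        # torch.Size([3, 224, 224])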
modules/dataset.py
ADDED
@@ -0,0 +1,68 @@
import os
from PIL import Image
import json
from torch.utils.data import Dataset
import numpy as np
import torch


class BaseDataset(Dataset):
    def __init__(self, args, tokenizer, split, processor):
        self.image_dir = args.image_dir
        self.ann_path = args.json_path
        self.max_seq_length = args.max_seq_length
        self.split = split
        self.tokenizer = tokenizer
        self.ann = json.loads(open(self.ann_path, 'r').read())
        self.examples = self.ann[self.split]
        self.processor = processor

    def preprocess_text(self, text):
        ids = self.tokenizer(text)[:self.max_seq_length]
        mask = [1] * len(ids)
        text_inputs = self.processor(text=text, return_tensors="pt", truncation=True, padding=False,
                                     max_length=self.max_seq_length)
        processor_ids = text_inputs['input_ids'].squeeze(0).tolist()
        processor_mask = text_inputs['attention_mask'].squeeze(0).tolist()
        return ids, mask, processor_ids, processor_mask

    def __len__(self):
        return len(self.examples)


class IuxrayMultiImageDataset(BaseDataset):
    def __getitem__(self, idx):
        example = self.examples[idx]
        report = example['report']
        report_ids, report_masks, processor_ids, processor_mask = self.preprocess_text(report)

        image_id = example['id']
        image_path = example['image_path']
        image_1 = Image.open(os.path.join(self.image_dir, image_path[0])).convert('RGB')
        image_2 = Image.open(os.path.join(self.image_dir, image_path[1])).convert('RGB')
        # MedCLIP processing
        image_inputs_1 = self.processor(images=image_1, return_tensors="pt")
        image_inputs_2 = self.processor(images=image_2, return_tensors="pt")
        image = torch.stack((image_inputs_1.pixel_values[0], image_inputs_2.pixel_values[0]), 0)

        seq_length = len(report_ids)
        processor_length = len(processor_ids)
        sample = (image_id, image, report_ids, report_masks, processor_ids, processor_mask, seq_length, processor_length)
        return sample


class MimiccxrSingleImageDataset(BaseDataset):
    def __getitem__(self, idx):
        example = self.examples[idx]
        report = example['report']
        report_ids, report_masks, processor_ids, processor_mask = self.preprocess_text(report)

        image_id = example['id']
        image_path = example['image_path']
        image = Image.open(os.path.join(self.image_dir, image_path[0])).convert('RGB')
        image_inputs = self.processor(images=image, return_tensors="pt")
        image = image_inputs.pixel_values[0]

        seq_length = len(report_ids)
        processor_length = len(processor_ids)
        sample = (image_id, image, report_ids, report_masks, processor_ids, processor_mask, seq_length, processor_length)
        return sample
modules/datasets.py
ADDED
@@ -0,0 +1,57 @@
import os
import json
import torch
from PIL import Image
from torch.utils.data import Dataset


class BaseDataset(Dataset):
    def __init__(self, args, tokenizer, split, transform=None):
        self.image_dir = args.image_dir
        self.ann_path = args.ann_path
        self.max_seq_length = args.max_seq_length
        self.split = split
        self.tokenizer = tokenizer
        self.transform = transform
        self.ann = json.loads(open(self.ann_path, 'r').read())

        self.examples = self.ann[self.split]
        for i in range(len(self.examples)):
            self.examples[i]['ids'] = tokenizer(self.examples[i]['report'])[:self.max_seq_length]
            self.examples[i]['mask'] = [1] * len(self.examples[i]['ids'])

    def __len__(self):
        return len(self.examples)


class IuxrayMultiImageDataset(BaseDataset):
    def __getitem__(self, idx):
        example = self.examples[idx]
        image_id = example['id']
        image_path = example['image_path']
        image_1 = Image.open(os.path.join(self.image_dir, image_path[0])).convert('RGB')
        image_2 = Image.open(os.path.join(self.image_dir, image_path[1])).convert('RGB')
        if self.transform is not None:
            image_1 = self.transform(image_1)
            image_2 = self.transform(image_2)
        image = torch.stack((image_1, image_2), 0)
        report_ids = example['ids']
        report_masks = example['mask']
        seq_length = len(report_ids)
        sample = (image_id, image, report_ids, report_masks, seq_length)
        return sample


class MimiccxrSingleImageDataset(BaseDataset):
    def __getitem__(self, idx):
        example = self.examples[idx]
        image_id = example['id']
        image_path = example['image_path']
        image = Image.open(os.path.join(self.image_dir, image_path[0])).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        report_ids = example['ids']
        report_masks = example['mask']
        seq_length = len(report_ids)
        sample = (image_id, image, report_ids, report_masks, seq_length)
        return sample
modules/decoder.py
ADDED
@@ -0,0 +1,50 @@
import numpy as np
import torch
import torch.nn as nn
import pickle
from typing import Tuple
from transformers import GPT2LMHeadModel
from .att_models import AttModel
import pdb


class MLP(nn.Module):

    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super(MLP, self).__init__()
        layers = []
        for i in range(len(sizes) - 1):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
            if i < len(sizes) - 2:
                layers.append(act())
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)


class DeCap(AttModel):

    def __init__(self, args, tokenizer):
        super(DeCap, self).__init__(args, tokenizer)

        # decoder: 4 layers transformer with 4 attention heads
        # the decoder is not pretrained
        with open('./decoder_config/decoder_config.pkl', 'rb') as f:
            config = pickle.load(f)
        # Change the parameters you need
        config.vocab_size = tokenizer.get_vocab_size()
        config.bos_token_id = tokenizer.bos_token_id
        config.eos_token_id = tokenizer.eos_token_id
        self.decoder = GPT2LMHeadModel(config)
        self.embedding_size = self.decoder.transformer.wte.weight.shape[1]
        self.prefix_size = 512
        self.clip_project = MLP((self.prefix_size, self.embedding_size))

    def _forward(self, clip_features, gpt_tokens):

        embedding_text = self.decoder.transformer.wte(gpt_tokens)
        embedding_clip = self.clip_project(clip_features)
        embedding_clip = embedding_clip.reshape(-1, 1, self.embedding_size)
        embedding_cat = torch.cat([embedding_clip, embedding_text], dim=1)
        out = self.decoder(inputs_embeds=embedding_cat)
        return out
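DeCap._forward prepends a single projected CLIP feature to the token embeddings, so the decoder sees [prefix, w_1, ..., w_T] and its logits are shifted by one position relative to the text. A hedged sketch of how a language-modeling loss could be computed from that output; the loss wiring below is an assumption for illustration, not taken from the uploaded trainer:

import torch
import torch.nn.functional as F

# out.logits has shape (batch, 1 + T, vocab): position 0 is predicted from the prefix alone.
def prefix_lm_loss(logits, gpt_tokens, pad_token_id=0):
    # drop the last position so prediction i targets text token i (assumed convention)
    pred = logits[:, :-1, :]
    return F.cross_entropy(pred.reshape(-1, pred.size(-1)),
                           gpt_tokens.reshape(-1),
                           ignore_index=pad_token_id)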
modules/encoder_decoder.py
ADDED
@@ -0,0 +1,391 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .att_model import pack_wrapper, AttModel


def clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def attention(query, key, value, mask=None, dropout=None):
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn


def subsequent_mask(size):
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(subsequent_mask) == 0


class Transformer(nn.Module):
    def __init__(self, encoder, decoder, src_embed, tgt_embed, rm):
        super(Transformer, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.rm = rm

    def forward(self, src, tgt, src_mask, tgt_mask):
        return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)

    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, hidden_states, src_mask, tgt, tgt_mask):
        memory = self.rm.init_memory(hidden_states.size(0)).to(hidden_states)
        memory = self.rm(self.tgt_embed(tgt), memory)
        return self.decoder(self.tgt_embed(tgt), hidden_states, src_mask, tgt_mask, memory)


class Encoder(nn.Module):
    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.d_model)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)


class EncoderLayer(nn.Module):
    def __init__(self, d_model, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(d_model, dropout), 2)
        self.d_model = d_model

    def forward(self, x, mask):
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        return self.sublayer[1](x, self.feed_forward)


class SublayerConnection(nn.Module):
    def __init__(self, d_model, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))


class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta


class Decoder(nn.Module):
    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.d_model)

    def forward(self, x, hidden_states, src_mask, tgt_mask, memory):
        for layer in self.layers:
            x = layer(x, hidden_states, src_mask, tgt_mask, memory)
        return self.norm(x)


class DecoderLayer(nn.Module):
    def __init__(self, d_model, self_attn, src_attn, feed_forward, dropout, rm_num_slots, rm_d_model):
        super(DecoderLayer, self).__init__()
        self.d_model = d_model
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(ConditionalSublayerConnection(d_model, dropout, rm_num_slots, rm_d_model), 3)

    def forward(self, x, hidden_states, src_mask, tgt_mask, memory):
        m = hidden_states
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask), memory)
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask), memory)
        return self.sublayer[2](x, self.feed_forward, memory)


class ConditionalSublayerConnection(nn.Module):
    def __init__(self, d_model, dropout, rm_num_slots, rm_d_model):
        super(ConditionalSublayerConnection, self).__init__()
        self.norm = ConditionalLayerNorm(d_model, rm_num_slots, rm_d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer, memory):
        return x + self.dropout(sublayer(self.norm(x, memory)))


class ConditionalLayerNorm(nn.Module):
    def __init__(self, d_model, rm_num_slots, rm_d_model, eps=1e-6):
        super(ConditionalLayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.rm_d_model = rm_d_model
        self.rm_num_slots = rm_num_slots
        self.eps = eps

        self.mlp_gamma = nn.Sequential(nn.Linear(rm_num_slots * rm_d_model, d_model),
                                       nn.ReLU(inplace=True),
                                       nn.Linear(rm_d_model, rm_d_model))

        self.mlp_beta = nn.Sequential(nn.Linear(rm_num_slots * rm_d_model, d_model),
                                      nn.ReLU(inplace=True),
                                      nn.Linear(d_model, d_model))

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0.1)

    def forward(self, x, memory):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        delta_gamma = self.mlp_gamma(memory)
        delta_beta = self.mlp_beta(memory)
        gamma_hat = self.gamma.clone()
        beta_hat = self.beta.clone()
        gamma_hat = torch.stack([gamma_hat] * x.size(0), dim=0)
        gamma_hat = torch.stack([gamma_hat] * x.size(1), dim=1)
        beta_hat = torch.stack([beta_hat] * x.size(0), dim=0)
        beta_hat = torch.stack([beta_hat] * x.size(1), dim=1)
        gamma_hat += delta_gamma
        beta_hat += delta_beta
        return gamma_hat * (x - mean) / (std + self.eps) + beta_hat


class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]

        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)

        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)


class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(F.relu(self.w_1(x))))


class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
super(Embeddings, self).__init__()
|
220 |
+
self.lut = nn.Embedding(vocab, d_model)
|
221 |
+
self.d_model = d_model
|
222 |
+
|
223 |
+
def forward(self, x):
|
224 |
+
return self.lut(x) * math.sqrt(self.d_model)
|
225 |
+
|
226 |
+
|
227 |
+
class PositionalEncoding(nn.Module):
|
228 |
+
def __init__(self, d_model, dropout, max_len=5000):
|
229 |
+
super(PositionalEncoding, self).__init__()
|
230 |
+
self.dropout = nn.Dropout(p=dropout)
|
231 |
+
|
232 |
+
pe = torch.zeros(max_len, d_model)
|
233 |
+
position = torch.arange(0, max_len).unsqueeze(1).float()
|
234 |
+
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
|
235 |
+
-(math.log(10000.0) / d_model))
|
236 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
237 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
238 |
+
pe = pe.unsqueeze(0)
|
239 |
+
self.register_buffer('pe', pe)
|
240 |
+
|
241 |
+
def forward(self, x):
|
242 |
+
x = x + self.pe[:, :x.size(1)]
|
243 |
+
return self.dropout(x)
|
244 |
+
|
245 |
+
|
246 |
+
class RelationalMemory(nn.Module):
|
247 |
+
|
248 |
+
def __init__(self, num_slots, d_model, num_heads=1):
|
249 |
+
super(RelationalMemory, self).__init__()
|
250 |
+
self.num_slots = num_slots
|
251 |
+
self.num_heads = num_heads
|
252 |
+
self.d_model = d_model
|
253 |
+
|
254 |
+
self.attn = MultiHeadedAttention(num_heads, d_model)
|
255 |
+
self.mlp = nn.Sequential(nn.Linear(self.d_model, self.d_model),
|
256 |
+
nn.ReLU(),
|
257 |
+
nn.Linear(self.d_model, self.d_model),
|
258 |
+
nn.ReLU())
|
259 |
+
|
260 |
+
self.W = nn.Linear(self.d_model, self.d_model * 2)
|
261 |
+
self.U = nn.Linear(self.d_model, self.d_model * 2)
|
262 |
+
|
263 |
+
def init_memory(self, batch_size):
|
264 |
+
memory = torch.stack([torch.eye(self.num_slots)] * batch_size)
|
265 |
+
if self.d_model > self.num_slots:
|
266 |
+
diff = self.d_model - self.num_slots
|
267 |
+
pad = torch.zeros((batch_size, self.num_slots, diff))
|
268 |
+
memory = torch.cat([memory, pad], -1)
|
269 |
+
elif self.d_model < self.num_slots:
|
270 |
+
memory = memory[:, :, :self.d_model]
|
271 |
+
|
272 |
+
return memory
|
273 |
+
|
274 |
+
def forward_step(self, input, memory):
|
275 |
+
# print('inputinputinputinputinput',input.size())
|
276 |
+
# print('memorymemorymemorymemorymemorymemory',memory.size())
|
277 |
+
|
278 |
+
memory = memory.reshape(-1, self.num_slots, self.d_model)
|
279 |
+
# if input.shape[0]!=memory.shape[0]:
|
280 |
+
# input=input.repeat(round(memory.shape[0]/input.shape[0]),1)
|
281 |
+
q = memory
|
282 |
+
k = torch.cat([memory, input.unsqueeze(1)], 1)
|
283 |
+
v = torch.cat([memory, input.unsqueeze(1)], 1)
|
284 |
+
next_memory = memory + self.attn(q, k, v)
|
285 |
+
next_memory = next_memory + self.mlp(next_memory)
|
286 |
+
|
287 |
+
gates = self.W(input.unsqueeze(1)) + self.U(torch.tanh(memory))
|
288 |
+
gates = torch.split(gates, split_size_or_sections=self.d_model, dim=2)
|
289 |
+
input_gate, forget_gate = gates
|
290 |
+
input_gate = torch.sigmoid(input_gate)
|
291 |
+
forget_gate = torch.sigmoid(forget_gate)
|
292 |
+
|
293 |
+
next_memory = input_gate * torch.tanh(next_memory) + forget_gate * memory
|
294 |
+
next_memory = next_memory.reshape(-1, self.num_slots * self.d_model)
|
295 |
+
|
296 |
+
return next_memory
|
297 |
+
|
298 |
+
def forward(self, inputs, memory):
|
299 |
+
outputs = []
|
300 |
+
for i in range(inputs.shape[1]):
|
301 |
+
memory = self.forward_step(inputs[:, i], memory)
|
302 |
+
outputs.append(memory)
|
303 |
+
outputs = torch.stack(outputs, dim=1)
|
304 |
+
|
305 |
+
return outputs
|
306 |
+
|
307 |
+
|
308 |
+
class EncoderDecoder(AttModel):
|
309 |
+
|
310 |
+
def make_model(self, tgt_vocab):
|
311 |
+
c = copy.deepcopy
|
312 |
+
attn = MultiHeadedAttention(self.num_heads, self.d_model)
|
313 |
+
ff = PositionwiseFeedForward(self.d_model, self.d_ff, self.dropout)
|
314 |
+
position = PositionalEncoding(self.d_model, self.dropout)
|
315 |
+
rm = RelationalMemory(num_slots=self.rm_num_slots, d_model=self.rm_d_model, num_heads=self.rm_num_heads)
|
316 |
+
model = Transformer(
|
317 |
+
Encoder(EncoderLayer(self.d_model, c(attn), c(ff), self.dropout), self.num_layers),
|
318 |
+
Decoder(
|
319 |
+
DecoderLayer(self.d_model, c(attn), c(attn), c(ff), self.dropout, self.rm_num_slots, self.rm_d_model),
|
320 |
+
self.num_layers),
|
321 |
+
lambda x: x,
|
322 |
+
nn.Sequential(Embeddings(self.d_model, tgt_vocab), c(position)),
|
323 |
+
rm)
|
324 |
+
for p in model.parameters():
|
325 |
+
if p.dim() > 1:
|
326 |
+
nn.init.xavier_uniform_(p)
|
327 |
+
return model
|
328 |
+
|
329 |
+
def __init__(self, args, tokenizer):
|
330 |
+
super(EncoderDecoder, self).__init__(args, tokenizer)
|
331 |
+
self.args = args
|
332 |
+
self.num_layers = args.num_layers
|
333 |
+
self.d_model = args.d_model
|
334 |
+
self.d_ff = args.d_ff
|
335 |
+
self.num_heads = args.num_heads
|
336 |
+
self.dropout = args.dropout
|
337 |
+
self.rm_num_slots = args.rm_num_slots
|
338 |
+
self.rm_num_heads = args.rm_num_heads
|
339 |
+
self.rm_d_model = args.rm_d_model
|
340 |
+
|
341 |
+
tgt_vocab = self.vocab_size + 1
|
342 |
+
|
343 |
+
self.model = self.make_model(tgt_vocab)
|
344 |
+
self.logit = nn.Linear(args.d_model, tgt_vocab)
|
345 |
+
|
346 |
+
def init_hidden(self, bsz):
|
347 |
+
return []
|
348 |
+
|
349 |
+
def _prepare_feature(self, fc_feats, att_feats, att_masks):
|
350 |
+
|
351 |
+
att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks)
|
352 |
+
memory = self.model.encode(att_feats, att_masks)
|
353 |
+
|
354 |
+
return fc_feats[..., :1], att_feats[..., :1], memory, att_masks
|
355 |
+
|
356 |
+
def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None):
|
357 |
+
att_feats, att_masks = self.clip_att(att_feats, att_masks)
|
358 |
+
att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
|
359 |
+
|
360 |
+
if att_masks is None:
|
361 |
+
att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long)
|
362 |
+
att_masks = att_masks.unsqueeze(-2)
|
363 |
+
|
364 |
+
if seq is not None:
|
365 |
+
# crop the last one
|
366 |
+
seq = seq[:, :-1]
|
367 |
+
seq_mask = (seq.data > 0)
|
368 |
+
seq_mask[:, 0] += True
|
369 |
+
|
370 |
+
seq_mask = seq_mask.unsqueeze(-2)
|
371 |
+
seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask)
|
372 |
+
else:
|
373 |
+
seq_mask = None
|
374 |
+
|
375 |
+
return att_feats, seq, att_masks, seq_mask
|
376 |
+
|
377 |
+
def _forward(self, fc_feats, att_feats, seq, att_masks=None):
|
378 |
+
|
379 |
+
att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq)
|
380 |
+
out = self.model(att_feats, seq, att_masks, seq_mask)
|
381 |
+
outputs = F.log_softmax(self.logit(out), dim=-1)
|
382 |
+
return outputs
|
383 |
+
|
384 |
+
def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
|
385 |
+
|
386 |
+
if len(state) == 0:
|
387 |
+
ys = it.unsqueeze(1)
|
388 |
+
else:
|
389 |
+
ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
|
390 |
+
out = self.model.decode(memory, mask, ys, subsequent_mask(ys.size(1)).to(memory.device))
|
391 |
+
return out[:, -1], [ys.unsqueeze(0)]
|
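Usage sketch for the relational memory above (illustrative only; the slot and model sizes are assumptions, and the import assumes the repository root is on PYTHONPATH):

import torch
from modules.encoder_decoder import RelationalMemory

rm = RelationalMemory(num_slots=3, d_model=512, num_heads=8)
tgt_embeddings = torch.randn(2, 10, 512)   # (batch, seq_len, d_model), dummy token embeddings
memory = rm.init_memory(batch_size=2)      # identity-initialised slots, zero-padded to (2, 3, 512)
states = rm(tgt_embeddings, memory)        # (2, 10, num_slots * d_model), one memory state per step
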
modules/loss.py
ADDED
@@ -0,0 +1,22 @@
import torch
import torch.nn as nn


class LanguageModelCriterion(nn.Module):
    def __init__(self):
        super(LanguageModelCriterion, self).__init__()

    def forward(self, input, target, mask):
        # truncate to the same size
        target = target[:, :input.size(1)]
        mask = mask[:, :input.size(1)]
        output = -input.gather(2, target.long().unsqueeze(2)).squeeze(2) * mask
        output = torch.sum(output) / torch.sum(mask)

        return output


def compute_loss(output, reports_ids, reports_masks):
    criterion = LanguageModelCriterion()
    loss = criterion(output, reports_ids[:, 1:], reports_masks[:, 1:]).mean()
    return loss

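A minimal sketch of what compute_loss expects (shapes are illustrative): per-token log-probabilities from the decoder, plus the padded report ids and masks whose first column is dropped inside the function to align targets with predictions.

import torch
from modules.loss import compute_loss  # assumes the repository root is on PYTHONPATH

batch, seq_len, vocab = 2, 6, 100
output = torch.log_softmax(torch.randn(batch, seq_len - 1, vocab), dim=-1)  # decoder log-probs
reports_ids = torch.randint(1, vocab, (batch, seq_len))                     # padded token ids
reports_masks = torch.ones(batch, seq_len)                                  # 1 for real tokens, 0 for padding
loss = compute_loss(output, reports_ids, reports_masks)                     # scalar tensor
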
modules/metrics.py
ADDED
@@ -0,0 +1,33 @@
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.meteor import Meteor
from pycocoevalcap.rouge import Rouge


def compute_scores(gts, res):
    """
    Performs the MS COCO evaluation using the Python 3 implementation (https://github.com/salaniz/pycocoevalcap)

    :param gts: Dictionary with the image ids and their gold captions
    :param res: Dictionary with the image ids and their generated captions
    :print: Evaluation score (the mean of the scores of all the instances) for each measure
    """

    # Set up scorers
    scorers = [
        (Bleu(4), ["BLEU_1", "BLEU_2", "BLEU_3", "BLEU_4"]),
        (Meteor(), "METEOR"),
        (Rouge(), "ROUGE_L")
    ]
    eval_res = {}
    # Compute score for each metric
    for scorer, method in scorers:
        try:
            score, scores = scorer.compute_score(gts, res, verbose=0)
        except TypeError:
            score, scores = scorer.compute_score(gts, res)
        if type(method) == list:
            for sc, m in zip(score, method):
                eval_res[m] = sc
        else:
            eval_res[method] = score
    return eval_res

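Usage sketch for compute_scores (toy strings; METEOR additionally needs a working Java runtime):

from modules.metrics import compute_scores  # assumes the repository root is on PYTHONPATH

gts = {0: ['the heart size is normal .'], 1: ['no acute cardiopulmonary abnormality .']}
res = {0: ['heart size is normal .'], 1: ['no acute abnormality .']}
print(compute_scores(gts, res))  # {'BLEU_1': ..., 'BLEU_4': ..., 'METEOR': ..., 'ROUGE_L': ...}
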
modules/optimizers.py
ADDED
@@ -0,0 +1,18 @@
import torch


def build_optimizer(args, model):
    ve_params = list(map(id, model.visual_extractor.parameters()))
    ed_params = filter(lambda x: id(x) not in ve_params, model.parameters())
    optimizer = getattr(torch.optim, args.optim)(
        [{'params': model.visual_extractor.parameters(), 'lr': args.lr_ve},
         {'params': ed_params, 'lr': args.lr_ed}],
        weight_decay=args.weight_decay,
        amsgrad=args.amsgrad
    )
    return optimizer


def build_lr_scheduler(args, optimizer):
    lr_scheduler = getattr(torch.optim.lr_scheduler, args.lr_scheduler)(optimizer, args.step_size, args.gamma)
    return lr_scheduler

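build_optimizer puts the visual extractor and the remaining parameters into two groups with separate learning rates. A small sketch with a stand-in module and assumed hyperparameter values:

import argparse
import torch.nn as nn
from modules.optimizers import build_optimizer, build_lr_scheduler  # assumes the repository root is on PYTHONPATH

class ToyModel(nn.Module):  # stand-in exposing the attribute the builder expects
    def __init__(self):
        super().__init__()
        self.visual_extractor = nn.Linear(4, 4)
        self.decoder_head = nn.Linear(4, 4)

args = argparse.Namespace(optim='Adam', lr_ve=5e-5, lr_ed=1e-4, weight_decay=5e-5,
                          amsgrad=True, lr_scheduler='StepLR', step_size=50, gamma=0.1)
optimizer = build_optimizer(args, ToyModel())   # two parameter groups: visual extractor vs. the rest
lr_scheduler = build_lr_scheduler(args, optimizer)
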
modules/tester.py
ADDED
@@ -0,0 +1,144 @@
import logging
import os
from abc import abstractmethod

import cv2
import numpy as np
import pandas as pd
import spacy
import torch
from tqdm import tqdm

from modules.utils import generate_heatmap


class BaseTester(object):
    def __init__(self, model, criterion, metric_ftns, args):
        self.args = args

        logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                            datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(args.n_gpu)
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        self.criterion = criterion
        self.metric_ftns = metric_ftns

        self.epochs = self.args.epochs
        self.save_dir = self.args.save_dir
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)

        self._load_checkpoint(args.load)

    @abstractmethod
    def test(self):
        raise NotImplementedError

    @abstractmethod
    def plot(self):
        raise NotImplementedError

    def _prepare_device(self, n_gpu_use):
        n_gpu = torch.cuda.device_count()
        if n_gpu_use > 0 and n_gpu == 0:
            self.logger.warning(
                "Warning: There's no GPU available on this machine, training will be performed on CPU.")
            n_gpu_use = 0
        if n_gpu_use > n_gpu:
            self.logger.warning(
                "Warning: The number of GPUs configured to use is {}, but only {} are available on this machine.".format(
                    n_gpu_use, n_gpu))
            n_gpu_use = n_gpu
        device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
        list_ids = list(range(n_gpu_use))
        return device, list_ids

    def _load_checkpoint(self, load_path):
        load_path = str(load_path)
        self.logger.info("Loading checkpoint: {} ...".format(load_path))
        checkpoint = torch.load(load_path)
        self.model.load_state_dict(checkpoint)


class Tester(BaseTester):
    def __init__(self, model, criterion, metric_ftns, args, test_dataloader):
        super(Tester, self).__init__(model, criterion, metric_ftns, args)
        self.test_dataloader = test_dataloader

    def test(self):
        self.logger.info('Start to evaluate in the test set.')
        self.model.eval()
        log = dict()
        with torch.no_grad():
            test_gts, test_res = [], []
            for batch_idx, (images_id, images, reports_ids, reports_masks, align_ids, align_masks) in enumerate(self.test_dataloader):
                images, reports_ids, reports_masks, align_ids, align_masks = images.to(self.device), reports_ids.to(self.device), \
                    reports_masks.to(self.device), align_ids.to(self.device), align_masks.to(self.device)
                output = self.model(reports_ids, align_ids, align_masks, images, mode='sample')
                reports = self.model.tokenizer.decode_batch(output.cpu().numpy())
                ground_truths = self.model.tokenizer.decode_batch(reports_ids[:, 1:].cpu().numpy())
                test_res.extend(reports)
                test_gts.extend(ground_truths)

            test_met = self.metric_ftns({i: [gt] for i, gt in enumerate(test_gts)},
                                        {i: [re] for i, re in enumerate(test_res)})
            log.update(**{'test_' + k: v for k, v in test_met.items()})
            print(log)

        test_res, test_gts = pd.DataFrame(test_res), pd.DataFrame(test_gts)
        test_res.to_csv(os.path.join(self.save_dir, "res.csv"), index=False, header=False)
        test_gts.to_csv(os.path.join(self.save_dir, "gts.csv"), index=False, header=False)

        return log

    def plot(self):
        assert self.args.batch_size == 1 and self.args.beam_size == 1
        self.logger.info('Start to plot attention weights in the test set.')
        os.makedirs(os.path.join(self.save_dir, "attentions"), exist_ok=True)
        os.makedirs(os.path.join(self.save_dir, "attentions_entities"), exist_ok=True)
        ner = spacy.load("en_core_sci_sm")
        mean = torch.tensor((0.485, 0.456, 0.406))
        std = torch.tensor((0.229, 0.224, 0.225))
        mean = mean[:, None, None]
        std = std[:, None, None]

        self.model.eval()
        with torch.no_grad():
            for batch_idx, (images_id, images, reports_ids, reports_masks) in tqdm(enumerate(self.test_dataloader)):
                images, reports_ids, reports_masks = images.to(self.device), reports_ids.to(
                    self.device), reports_masks.to(self.device)
                output, _ = self.model(images, mode='sample')
                image = torch.clamp((images[0].cpu() * std + mean) * 255, 0, 255).int().cpu().numpy()
                report = self.model.tokenizer.decode_batch(output.cpu().numpy())[0].split()

                char2word = [idx for word_idx, word in enumerate(report) for idx in [word_idx] * (len(word) + 1)][:-1]

                attention_weights = self.model.encoder_decoder.attention_weights[:-1]
                assert len(attention_weights) == len(report)
                for word_idx, (attns, word) in enumerate(zip(attention_weights, report)):
                    for layer_idx, attn in enumerate(attns):
                        os.makedirs(os.path.join(self.save_dir, "attentions", "{:04d}".format(batch_idx),
                                                 "layer_{}".format(layer_idx)), exist_ok=True)

                        heatmap = generate_heatmap(image, attn.mean(1).squeeze())
                        cv2.imwrite(os.path.join(self.save_dir, "attentions", "{:04d}".format(batch_idx),
                                                 "layer_{}".format(layer_idx), "{:04d}_{}.png".format(word_idx, word)),
                                    heatmap)

                for ne_idx, ne in enumerate(ner(" ".join(report)).ents):
                    for layer_idx in range(len(attention_weights[0])):
                        os.makedirs(os.path.join(self.save_dir, "attentions_entities", "{:04d}".format(batch_idx),
                                                 "layer_{}".format(layer_idx)), exist_ok=True)
                        attn = [attns[layer_idx] for attns in
                                attention_weights[char2word[ne.start_char]:char2word[ne.end_char] + 1]]
                        attn = np.concatenate(attn, axis=2)
                        heatmap = generate_heatmap(image, attn.mean(1).mean(1).squeeze())
                        cv2.imwrite(os.path.join(self.save_dir, "attentions_entities", "{:04d}".format(batch_idx),
                                                 "layer_{}".format(layer_idx), "{:04d}_{}.png".format(ne_idx, ne)),
                                    heatmap)

modules/tokenizers.py
ADDED
@@ -0,0 +1,95 @@
import json
import re
from collections import Counter


class Tokenizer(object):
    def __init__(self, args):
        self.ann_path = args.ann_path
        self.threshold = args.threshold
        self.dataset_name = args.dataset_name
        if self.dataset_name == 'iu_xray':
            self.clean_report = self.clean_report_iu_xray
        else:
            self.clean_report = self.clean_report_mimic_cxr
        self.ann = json.loads(open(self.ann_path, 'r').read())
        self.token2idx, self.idx2token = self.create_vocabulary()

    def create_vocabulary(self):
        total_tokens = []

        for example in self.ann['train']:
            tokens = self.clean_report(example['report']).split()
            for token in tokens:
                total_tokens.append(token)

        counter = Counter(total_tokens)
        vocab = [k for k, v in counter.items() if v >= self.threshold] + ['<unk>']
        vocab.sort()
        token2idx, idx2token = {}, {}
        for idx, token in enumerate(vocab):
            token2idx[token] = idx + 1
            idx2token[idx + 1] = token
        return token2idx, idx2token

    def clean_report_iu_xray(self, report):
        report_cleaner = lambda t: t.replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '') \
            .replace('. 2. ', '. ').replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ') \
            .replace(' 2. ', '. ').replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \
            .strip().lower().split('. ')
        sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '').
                                        replace('\\', '').replace("'", '').strip().lower())
        tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []]
        report = ' . '.join(tokens) + ' .'
        return report

    def clean_report_mimic_cxr(self, report):
        report_cleaner = lambda t: t.replace('\n', ' ').replace('__', '_').replace('__', '_').replace('__', '_') \
            .replace('__', '_').replace('__', '_').replace('__', '_').replace('__', '_').replace('  ', ' ') \
            .replace('  ', ' ').replace('  ', ' ').replace('  ', ' ').replace('  ', ' ').replace('  ', ' ') \
            .replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.') \
            .replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '').replace('. 2. ', '. ') \
            .replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ').replace(' 2. ', '. ') \
            .replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \
            .strip().lower().split('. ')
        sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '')
                                        .replace('\\', '').replace("'", '').strip().lower())
        tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []]
        report = ' . '.join(tokens) + ' .'
        return report

    def get_token_by_id(self, id):
        return self.idx2token[id]

    def get_id_by_token(self, token):
        if token not in self.token2idx:
            return self.token2idx['<unk>']
        return self.token2idx[token]

    def get_vocab_size(self):
        return len(self.token2idx)

    def __call__(self, report):
        tokens = self.clean_report(report).split()
        ids = []
        for token in tokens:
            ids.append(self.get_id_by_token(token))
        ids = [0] + ids + [0]
        return ids

    def decode(self, ids):
        txt = ''
        for i, idx in enumerate(ids):
            if idx > 0:
                if i >= 1:
                    txt += ' '
                txt += self.idx2token[idx]
            else:
                break
        return txt

    def decode_batch(self, ids_batch):
        out = []
        for ids in ids_batch:
            out.append(self.decode(ids))
        return out

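A round-trip sketch for the Tokenizer: the vocabulary is built from the training split of the annotation file, __call__ maps a cleaned report to ids bracketed by 0, and decode stops at the first 0 it meets. The Namespace fields simply mirror what __init__ reads; the threshold value is an assumption.

import argparse
from modules.tokenizers import Tokenizer  # assumes the repository root is on PYTHONPATH

args = argparse.Namespace(ann_path='data/annotation.json', threshold=3, dataset_name='iu_xray')
tokenizer = Tokenizer(args)
ids = tokenizer('The heart size is normal. 2. No acute abnormality.')
text = tokenizer.decode(ids[1:])  # skip the leading 0; decoding stops at the trailing 0
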
modules/trainer.py
ADDED
@@ -0,0 +1,255 @@
import os
from abc import abstractmethod
import json
import time
import torch
import pandas as pd
from numpy import inf


class BaseTrainer(object):
    def __init__(self, model, criterion, metric_ftns, optimizer, args):
        self.args = args

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(args.n_gpu)
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer

        self.epochs = self.args.epochs
        self.save_period = self.args.save_period

        self.mnt_mode = args.monitor_mode
        self.mnt_metric = 'val_' + args.monitor_metric
        self.mnt_metric_test = 'test_' + args.monitor_metric
        assert self.mnt_mode in ['min', 'max']

        self.mnt_best = inf if self.mnt_mode == 'min' else -inf
        self.early_stop = getattr(self.args, 'early_stop', inf)

        self.start_epoch = 1
        self.checkpoint_dir = args.save_dir

        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)

        if args.resume is not None:
            self._resume_checkpoint(args.resume)

        self.best_recorder = {'val': {self.mnt_metric: self.mnt_best},
                              'test': {self.mnt_metric_test: self.mnt_best}}

    @abstractmethod
    def _train_epoch(self, epoch):
        raise NotImplementedError

    def train(self):
        not_improved_count = 0
        for epoch in range(self.start_epoch, self.epochs + 1):
            result = self._train_epoch(epoch)

            # save logged informations into log dict
            log = {'epoch': epoch}
            log.update(result)
            self._record_best(log)

            # print logged informations to the screen
            for key, value in log.items():
                print('\t{:15s}: {}'.format(str(key), value))

            # evaluate model performance according to configured metric, save best checkpoint as model_best
            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not, according to specified metric (mnt_metric)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
                except KeyError:
                    print("Warning: Metric '{}' is not found. Model performance monitoring is disabled.".format(
                        self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if not_improved_count > self.early_stop:
                    print("Validation performance didn't improve for {} epochs. Training stops.".format(
                        self.early_stop))
                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)
        self._print_best()
        self._print_best_to_file()

    def _print_best_to_file(self):
        crt_time = time.asctime(time.localtime(time.time()))
        self.best_recorder['val']['time'] = crt_time
        self.best_recorder['test']['time'] = crt_time
        self.best_recorder['val']['seed'] = self.args.seed
        self.best_recorder['test']['seed'] = self.args.seed
        self.best_recorder['val']['best_model_from'] = 'val'
        self.best_recorder['test']['best_model_from'] = 'test'

        if not os.path.exists(self.args.record_dir):
            os.makedirs(self.args.record_dir)
        record_path = os.path.join(self.args.record_dir, self.args.dataset_name + '.csv')
        if not os.path.exists(record_path):
            record_table = pd.DataFrame()
        else:
            record_table = pd.read_csv(record_path)
        record_table = record_table.append(self.best_recorder['val'], ignore_index=True)
        record_table = record_table.append(self.best_recorder['test'], ignore_index=True)
        record_table.to_csv(record_path, index=False)

    def _prepare_device(self, n_gpu_use):
        n_gpu = torch.cuda.device_count()
        if n_gpu_use > 0 and n_gpu == 0:
            print("Warning: There's no GPU available on this machine, training will be performed on CPU.")
            n_gpu_use = 0
        if n_gpu_use > n_gpu:
            print(
                "Warning: The number of GPUs configured to use is {}, but only {} are available on this machine.".format(
                    n_gpu_use, n_gpu))
            n_gpu_use = n_gpu
        device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
        list_ids = list(range(n_gpu_use))
        return device, list_ids

    def _save_checkpoint(self, epoch, save_best=False):
        state = {
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.mnt_best
        }
        filename = os.path.join(self.checkpoint_dir, 'current_checkpoint.pth')
        torch.save(state, filename)
        print("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = os.path.join(self.checkpoint_dir, 'model_best.pth')
            torch.save(state, best_path)
            print("Saving current best: model_best.pth ...")

    def _resume_checkpoint(self, resume_path):
        resume_path = str(resume_path)
        print("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path)
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])

        print("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))

    def _record_best(self, log):
        improved_val = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.best_recorder['val'][
            self.mnt_metric]) or \
                       (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.best_recorder['val'][self.mnt_metric])
        if improved_val:
            self.best_recorder['val'].update(log)

        improved_test = (self.mnt_mode == 'min' and log[self.mnt_metric_test] <= self.best_recorder['test'][
            self.mnt_metric_test]) or \
                        (self.mnt_mode == 'max' and log[self.mnt_metric_test] >= self.best_recorder['test'][
                            self.mnt_metric_test])
        if improved_test:
            self.best_recorder['test'].update(log)

    def _print_best(self):
        print('Best results (w.r.t {}) in validation set:'.format(self.args.monitor_metric))
        for key, value in self.best_recorder['val'].items():
            print('\t{:15s}: {}'.format(str(key), value))

        print('Best results (w.r.t {}) in test set:'.format(self.args.monitor_metric))
        for key, value in self.best_recorder['test'].items():
            print('\t{:15s}: {}'.format(str(key), value))


if not os.path.exists('valreports/'):
    os.makedirs('valreports/')
if not os.path.exists('testreports/'):
    os.makedirs('testreports/')


class Trainer(BaseTrainer):
    def __init__(self, model, criterion, metric_ftns, optimizer, args, lr_scheduler, train_dataloader, val_dataloader,
                 test_dataloader):
        super(Trainer, self).__init__(model, criterion, metric_ftns, optimizer, args)
        self.lr_scheduler = lr_scheduler
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.test_dataloader = test_dataloader

    def _train_epoch(self, epoch):

        train_loss = 0
        self.model.train()
        for batch_idx, (images_id, images, reports_ids, reports_masks) in enumerate(self.train_dataloader):
            images, reports_ids, reports_masks = images.to(self.device), reports_ids.to(self.device), reports_masks.to(
                self.device)
            output = self.model(images, reports_ids, mode='train')
            loss = self.criterion(output, reports_ids, reports_masks)
            train_loss += loss.item()
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_value_(self.model.parameters(), 0.1)
            self.optimizer.step()
        log = {'train_loss': train_loss / len(self.train_dataloader)}

        self.model.eval()
        with torch.no_grad():
            result_report_val = []
            val_gts, val_res = [], []
            for batch_idx, (images_id, images, reports_ids, reports_masks) in enumerate(self.val_dataloader):
                images, reports_ids, reports_masks = images.to(self.device), reports_ids.to(
                    self.device), reports_masks.to(self.device)
                output = self.model(images, mode='sample')
                reports = self.model.tokenizer.decode_batch(output.cpu().numpy())
                for i in range(reports_ids.shape[0]):
                    temp1 = {'reports_ids': images_id[i], 'reports': reports[i]}
                    result_report_val.append(temp1)
                ground_truths = self.model.tokenizer.decode_batch(reports_ids[:, 1:].cpu().numpy())
                val_res.extend(reports)
                val_gts.extend(ground_truths)
            val_met = self.metric_ftns({i: [gt] for i, gt in enumerate(val_gts)},
                                       {i: [re] for i, re in enumerate(val_res)})
            log.update(**{'val_' + k: v for k, v in val_met.items()})
            resFileval = 'valreports/mixed-' + str(epoch) + '.json'
            json.dump(result_report_val, open(resFileval, 'w'))

        self.model.eval()
        with torch.no_grad():
            result_report_test = []
            test_gts, test_res = [], []
            for batch_idx, (images_id, images, reports_ids, reports_masks) in enumerate(self.test_dataloader):
                images, reports_ids, reports_masks = images.to(self.device), reports_ids.to(
                    self.device), reports_masks.to(self.device)
                output = self.model(images, mode='sample')
                reports = self.model.tokenizer.decode_batch(output.cpu().numpy())
                for i in range(reports_ids.shape[0]):
                    temp = {'reports_ids': images_id[i], 'reports': reports[i]}
                    result_report_test.append(temp)
                ground_truths = self.model.tokenizer.decode_batch(reports_ids[:, 1:].cpu().numpy())
                test_res.extend(reports)
                test_gts.extend(ground_truths)
            test_met = self.metric_ftns({i: [gt] for i, gt in enumerate(test_gts)},
                                        {i: [re] for i, re in enumerate(test_res)})
            log.update(**{'test_' + k: v for k, v in test_met.items()})
            resFiletest = 'testreports/mixed-' + str(epoch) + '.json'
            json.dump(result_report_test, open(resFiletest, 'w'))
        self.lr_scheduler.step()

        return log

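How the pieces above are typically wired together; a schematic sketch in which args, model and the three dataloaders are assumed to come from the training script and the model definition in models/, rather than being constructed here.

from modules.loss import compute_loss
from modules.metrics import compute_scores
from modules.optimizers import build_optimizer, build_lr_scheduler
from modules.trainer import Trainer

# `args`, `model`, `train_dataloader`, `val_dataloader`, `test_dataloader` are assumed to exist
optimizer = build_optimizer(args, model)
lr_scheduler = build_lr_scheduler(args, optimizer)
trainer = Trainer(model, compute_loss, compute_scores, optimizer, args,
                  lr_scheduler, train_dataloader, val_dataloader, test_dataloader)
trainer.train()
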
modules/utils.py
ADDED
@@ -0,0 +1,55 @@
import torch


def penalty_builder(penalty_config):
    if penalty_config == '':
        return lambda x, y: y
    pen_type, alpha = penalty_config.split('_')
    alpha = float(alpha)
    if pen_type == 'wu':
        return lambda x, y: length_wu(x, y, alpha)
    if pen_type == 'avg':
        return lambda x, y: length_average(x, y, alpha)


def length_wu(length, logprobs, alpha=0.):
    """
    NMT length re-ranking score from
    "Google's Neural Machine Translation System" :cite:`wu2016google`.
    """

    modifier = (((5 + length) ** alpha) /
                ((5 + 1) ** alpha))
    return logprobs / modifier


def length_average(length, logprobs, alpha=0.):
    """
    Returns the average probability of tokens in a sequence.
    """
    return logprobs / length


def split_tensors(n, x):
    if torch.is_tensor(x):
        assert x.shape[0] % n == 0
        x = x.reshape(x.shape[0] // n, n, *x.shape[1:]).unbind(1)
    elif type(x) is list or type(x) is tuple:
        x = [split_tensors(n, _) for _ in x]
    elif x is None:
        x = [None] * n
    return x


def repeat_tensors(n, x):
    """
    For a tensor of size Bx..., we repeat it n times, and make it Bnx...
    For collections, do nested repeat
    """
    if torch.is_tensor(x):
        x = x.unsqueeze(1)  # Bx1x...
        x = x.expand(-1, n, *([-1] * len(x.shape[2:])))  # Bxnx...
        x = x.reshape(x.shape[0] * n, *x.shape[2:])  # Bnx...
    elif type(x) is list or type(x) is tuple:
        x = [repeat_tensors(n, _) for _ in x]
    return x

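penalty_builder parses strings such as 'wu_0.9' or 'avg_1.0' into a length-penalty function applied when re-ranking beam hypotheses; a quick illustration with made-up numbers:

import torch
from modules.utils import penalty_builder  # assumes the repository root is on PYTHONPATH

penalty = penalty_builder('wu_0.9')              # Wu et al. length normalisation with alpha = 0.9
logprobs = torch.tensor([-4.2, -6.0])            # dummy cumulative beam log-probabilities
rescored = penalty(12, logprobs)                 # divides by ((5 + 12) ** 0.9) / (6 ** 0.9)
unchanged = penalty_builder('')(12, logprobs)    # empty config: log-probabilities returned as-is
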
modules/visual_extractor.py
ADDED
@@ -0,0 +1,53 @@
import os

os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

from medclip import MedCLIPModel, MedCLIPVisionModelViT
from medclip import MedCLIPProcessor
from PIL import Image
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F


class VisualExtractor(nn.Module):
    # prepare for the demo image and text
    def __init__(self, args):
        super(VisualExtractor, self).__init__()
        self.model = MedCLIPModel(vision_cls=MedCLIPVisionModelViT)
        self.model.from_pretrained()
        self.model.cuda()
        self.processor = MedCLIPProcessor()
        with torch.no_grad():
            self.prompt = torch.load('prompt/prompt.pth')

    def forward(self, images):
        a = []
        for i in images:
            inputs = self.processor(text="lungs", images=i, return_tensors="pt", padding=True)
            outputs = self.model(**inputs)
            feats = outputs['img_embeds']
            a.append(feats)
        batch_feats = torch.stack(a, dim=0)

        ha = []
        for i in range(batch_feats.shape[0]):
            b = batch_feats[i].unsqueeze(1)
            b = b.repeat(self.prompt.shape[0], 1, 1).transpose(-2, -1)
            c_t = torch.bmm(self.prompt, b)
            c_t = c_t.float()
            alpha = F.softmax(c_t)
            aa = alpha * self.prompt
            sum_a = aa.sum(axis=0)
            ha.append(sum_a)
        featsem = torch.stack(ha, dim=0)

        feats = torch.cat((featsem, batch_feats), dim=2)

        patch_feats = feats.repeat(1, 49, 1)
        batch_feats1 = feats.squeeze(1)
        avg_feats = batch_feats1

        return patch_feats, avg_feats

prompt/prompt.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4b03692f5ba61e9d50d10556cdbb724ed6249668873bad099bc6548af618a7d0
size 20480747

pycocoevalcap/README.md
ADDED
@@ -0,0 +1,23 @@
Microsoft COCO Caption Evaluation Tools <br />
---

Modified the code to work with Python 3. <br />

### Requirements
* Python 3.x
* Java 1.8
* pycocotools

---

### Tested on
* Windows 10, Python 3.5.

---
### To fix Windows JVM memory error: <br />
Add the following in System Variables <br />
Variable name : _JAVA_OPTIONS <br />
Variable value : -Xmx1024M <br />

---
Original code : https://github.com/tylin/coco-caption <br />

pycocoevalcap/__init__.py
ADDED
@@ -0,0 +1 @@
__author__ = 'tylin'

pycocoevalcap/bleu/LICENSE
ADDED
@@ -0,0 +1,19 @@
Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

pycocoevalcap/bleu/__init__.py
ADDED
@@ -0,0 +1 @@
__author__ = 'tylin'

pycocoevalcap/bleu/bleu.py
ADDED
@@ -0,0 +1,57 @@
#!/usr/bin/env python
#
# File Name : bleu.py
#
# Description : Wrapper for BLEU scorer.
#
# Creation Date : 06-01-2015
# Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT
# Authors : Hao Fang <hfang@uw.edu> and Tsung-Yi Lin <tl483@cornell.edu>

# Last modified : Wed 22 May 2019 08:10:00 PM EDT
# By Sabarish Sivanath
# To support Python 3

from .bleu_scorer import BleuScorer


class Bleu:
    def __init__(self, n=4):
        # default: compute BLEU score up to 4-grams
        self._n = n
        self._hypo_for_image = {}
        self.ref_for_image = {}

    def compute_score(self, gts, res, score_option='closest', verbose=1):
        '''
        Inputs:
            gts - ground truths
            res - predictions
            score_option - {shortest, closest, average}
            verbose - 1 or 0
        Outputs:
            BLEU scores
        '''
        assert (gts.keys() == res.keys())
        imgIds = gts.keys()

        bleu_scorer = BleuScorer(n=self._n)
        for id in imgIds:
            hypo = res[id]
            ref = gts[id]

            # Sanity check.
            assert (type(hypo) is list)
            assert (len(hypo) == 1)
            assert (type(ref) is list)
            # assert(len(ref) >= 1)

            bleu_scorer += (hypo[0], ref)

        score, scores = bleu_scorer.compute_score(option=score_option, verbose=verbose)

        # return (bleu, bleu_info)
        return score, scores

    def method(self):
        return "Bleu"

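The wrapper is driven with dicts that map an image id to a one-element hypothesis list and a list of references; a toy sketch:

from pycocoevalcap.bleu.bleu import Bleu

gts = {'img1': ['a dog runs in the park', 'a dog is running outside']}
res = {'img1': ['a dog runs in a park']}
score, scores = Bleu(4).compute_score(gts, res, verbose=0)
# score is [BLEU-1, BLEU-2, BLEU-3, BLEU-4] over the corpus; scores holds the per-image values
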
pycocoevalcap/bleu/bleu_scorer.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# bleu_scorer.py
|
2 |
+
# David Chiang <chiang@isi.edu>
|
3 |
+
|
4 |
+
# Copyright (c) 2004-2006 University of Maryland. All rights
|
5 |
+
# reserved. Do not redistribute without permission from the
|
6 |
+
# author. Not for commercial use.
|
7 |
+
|
8 |
+
# Modified by:
|
9 |
+
# Hao Fang <hfang@uw.edu>
|
10 |
+
# Tsung-Yi Lin <tl483@cornell.edu>
|
11 |
+
|
12 |
+
# Last modified : Wed 22 May 2019 08:10:00 PM EDT
|
13 |
+
# By Sabarish Sivanath
|
14 |
+
# To support Python 3
|
15 |
+
|
16 |
+
'''Provides:
|
17 |
+
cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
|
18 |
+
cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
|
19 |
+
'''
|
20 |
+
|
21 |
+
import copy
|
22 |
+
import sys, math, re
|
23 |
+
from collections import defaultdict
|
24 |
+
|
25 |
+
def precook(s, n=4, out=False):
|
26 |
+
"""Takes a string as input and returns an object that can be given to
|
27 |
+
either cook_refs or cook_test. This is optional: cook_refs and cook_test
|
28 |
+
can take string arguments as well."""
|
29 |
+
words = s.split()
|
30 |
+
counts = defaultdict(int)
|
31 |
+
for k in range(1,n+1):
|
32 |
+
for i in range(len(words)-k+1):
|
33 |
+
ngram = tuple(words[i:i+k])
|
34 |
+
counts[ngram] += 1
|
35 |
+
return (len(words), counts)
|
36 |
+
|
37 |
+
def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average"
|
38 |
+
'''Takes a list of reference sentences for a single segment
|
39 |
+
and returns an object that encapsulates everything that BLEU
|
40 |
+
needs to know about them.'''
|
41 |
+
|
42 |
+
reflen = []
|
43 |
+
maxcounts = {}
|
44 |
+
for ref in refs:
|
45 |
+
rl, counts = precook(ref, n)
|
46 |
+
reflen.append(rl)
|
47 |
+
for (ngram,count) in counts.items():
|
48 |
+
maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
|
49 |
+
|
50 |
+
# Calculate effective reference sentence length.
|
51 |
+
if eff == "shortest":
|
52 |
+
reflen = min(reflen)
|
53 |
+
elif eff == "average":
|
54 |
+
reflen = float(sum(reflen))/len(reflen)
|
55 |
+
|
56 |
+
## lhuang: N.B.: leave reflen computaiton to the very end!!
|
57 |
+
|
58 |
+
## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design)
|
59 |
+
|
60 |
+
return (reflen, maxcounts)
|
61 |
+
|
62 |
+
def cook_test(test, refs , eff=None, n=4):
|
63 |
+
'''Takes a test sentence and returns an object that
|
64 |
+
encapsulates everything that BLEU needs to know about it.'''
|
65 |
+
|
66 |
+
reflen = refs[0]
|
67 |
+
refmaxcounts = refs[1]
|
68 |
+
|
69 |
+
testlen, counts = precook(test, n, True)
|
70 |
+
|
71 |
+
result = {}
|
72 |
+
|
73 |
+
# Calculate effective reference sentence length.
|
74 |
+
|
75 |
+
if eff == "closest":
|
76 |
+
result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1]
|
77 |
+
else: ## i.e., "average" or "shortest" or None
|
78 |
+
result["reflen"] = reflen
|
79 |
+
|
80 |
+
result["testlen"] = testlen
|
81 |
+
|
82 |
+
result["guess"] = [max(0,testlen-k+1) for k in range(1,n+1)]
|
83 |
+
|
84 |
+
result['correct'] = [0]*n
|
85 |
+
for (ngram, count) in counts.items():
|
86 |
+
result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count)
|
87 |
+
|
88 |
+
return result
|
89 |
+
|
90 |
+
class BleuScorer(object):
|
91 |
+
"""Bleu scorer.
|
92 |
+
"""
|
93 |
+
|
94 |
+
__slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen"
|
95 |
+
# special_reflen is used in oracle (proportional effective ref len for a node).
|
96 |
+
|
97 |
+
def copy(self):
|
98 |
+
''' copy the refs.'''
|
99 |
+
new = BleuScorer(n=self.n)
|
100 |
+
new.ctest = copy.copy(self.ctest)
|
101 |
+
new.crefs = copy.copy(self.crefs)
|
102 |
+
new._score = None
|
103 |
+
return new
|
104 |
+
|
105 |
+
def __init__(self, test=None, refs=None, n=4, special_reflen=None):
|
106 |
+
''' singular instance '''
|
107 |
+
|
108 |
+
self.n = n
|
109 |
+
self.crefs = []
|
110 |
+
self.ctest = []
|
111 |
+
self.cook_append(test, refs)
|
112 |
+
self.special_reflen = special_reflen
|
113 |
+
|
114 |
+
def cook_append(self, test, refs):
|
115 |
+
'''called by constructor and __iadd__ to avoid creating new instances.'''
|
116 |
+
|
117 |
+
if refs is not None:
|
118 |
+
self.crefs.append(cook_refs(refs))
|
119 |
+
if test is not None:
|
120 |
+
cooked_test = cook_test(test, self.crefs[-1])
|
121 |
+
self.ctest.append(cooked_test) ## N.B.: -1
|
122 |
+
else:
|
123 |
+
self.ctest.append(None) # lens of crefs and ctest have to match
|
124 |
+
|
125 |
+
self._score = None ## need to recompute
|
126 |
+
|
127 |
+
def ratio(self, option=None):
|
128 |
+
self.compute_score(option=option)
|
129 |
+
return self._ratio
|
130 |
+
|
131 |
+
def score_ratio(self, option=None):
|
132 |
+
'''return (bleu, len_ratio) pair'''
|
133 |
+
return (self.fscore(option=option), self.ratio(option=option))
|
134 |
+
|
135 |
+
def score_ratio_str(self, option=None):
|
136 |
+
return "%.4f (%.2f)" % self.score_ratio(option)
|
137 |
+
|
138 |
+
def reflen(self, option=None):
|
139 |
+
self.compute_score(option=option)
|
140 |
+
        return self._reflen

    def testlen(self, option=None):
        self.compute_score(option=option)
        return self._testlen

    def retest(self, new_test):
        if type(new_test) is str:
            new_test = [new_test]
        assert len(new_test) == len(self.crefs), new_test
        self.ctest = []
        for t, rs in zip(new_test, self.crefs):
            self.ctest.append(cook_test(t, rs))
        self._score = None

        return self

    def rescore(self, new_test):
        ''' replace test(s) with new test(s), and returns the new score.'''

        return self.retest(new_test).compute_score()

    def size(self):
        assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
        return len(self.crefs)

    def __iadd__(self, other):
        '''add an instance (e.g., from another sentence).'''

        if type(other) is tuple:
            ## avoid creating new BleuScorer instances
            self.cook_append(other[0], other[1])
        else:
            assert self.compatible(other), "incompatible BLEUs."
            self.ctest.extend(other.ctest)
            self.crefs.extend(other.crefs)
            self._score = None ## need to recompute

        return self

    def compatible(self, other):
        return isinstance(other, BleuScorer) and self.n == other.n

    def single_reflen(self, option="average"):
        return self._single_reflen(self.crefs[0][0], option)

    def _single_reflen(self, reflens, option=None, testlen=None):

        if option == "shortest":
            reflen = min(reflens)
        elif option == "average":
            reflen = float(sum(reflens))/len(reflens)
        elif option == "closest":
            reflen = min((abs(l-testlen), l) for l in reflens)[1]
        else:
            assert False, "unsupported reflen option %s" % option

        return reflen

    def recompute_score(self, option=None, verbose=0):
        self._score = None
        return self.compute_score(option, verbose)

    def compute_score(self, option=None, verbose=0):
        n = self.n
        small = 1e-9
        tiny = 1e-15 ## so that if guess is 0 still return 0
        bleu_list = [[] for _ in range(n)]

        if self._score is not None:
            return self._score

        if option is None:
            option = "average" if len(self.crefs) == 1 else "closest"

        self._testlen = 0
        self._reflen = 0
        totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n}

        # for each sentence
        for comps in self.ctest:
            testlen = comps['testlen']
            self._testlen += testlen

            if self.special_reflen is None: ## need computation
                reflen = self._single_reflen(comps['reflen'], option, testlen)
            else:
                reflen = self.special_reflen

            self._reflen += reflen

            for key in ['guess','correct']:
                for k in range(n):
                    totalcomps[key][k] += comps[key][k]

            # append per image bleu score
            bleu = 1.
            for k in range(n):
                bleu *= (float(comps['correct'][k]) + tiny) \
                        /(float(comps['guess'][k]) + small)
                bleu_list[k].append(bleu ** (1./(k+1)))
            ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division
            if ratio < 1:
                for k in range(n):
                    bleu_list[k][-1] *= math.exp(1 - 1/ratio)

            if verbose > 1:
                print(comps, reflen)

        totalcomps['reflen'] = self._reflen
        totalcomps['testlen'] = self._testlen

        bleus = []
        bleu = 1.
        for k in range(n):
            bleu *= float(totalcomps['correct'][k] + tiny) \
                    / (totalcomps['guess'][k] + small)
            bleus.append(bleu ** (1./(k+1)))
        ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division
        if ratio < 1:
            for k in range(n):
                bleus[k] *= math.exp(1 - 1/ratio)

        if verbose > 0:
            print(totalcomps)
            print("ratio:", ratio)

        self._score = bleus
        return self._score, bleu_list
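A minimal usage sketch for the BleuScorer class above (not part of the uploaded files): the hypothesis and reference strings are invented, and the import path assumes the package layout shown in this commit.

from pycocoevalcap.bleu.bleu_scorer import BleuScorer

scorer = BleuScorer(n=4)
# each += takes (hypothesis string, list of reference strings)
scorer += ("the lungs are clear", ["the lungs are clear", "no focal consolidation is seen"])
scorer += ("there is a small effusion", ["small left pleural effusion"])
score, scores = scorer.compute_score(option="closest")
print(score)   # corpus-level BLEU-1..4
print(scores)  # per-sentence BLEU lists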
pycocoevalcap/cider/__init__.py
ADDED
@@ -0,0 +1 @@
__author__ = 'tylin'
pycocoevalcap/cider/cider.py
ADDED
@@ -0,0 +1,55 @@
# Filename: cider.py
#
# Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric
#              by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
#
# Creation Date: Sun Feb 8 14:16:54 2015
#
# Authors: Ramakrishna Vedantam <vrama91@vt.edu> and Tsung-Yi Lin <tl483@cornell.edu>

from .cider_scorer import CiderScorer
import pdb

class Cider:
    """
    Main Class to compute the CIDEr metric
    """
    def __init__(self, test=None, refs=None, n=4, sigma=6.0):
        # set cider to sum over 1 to 4-grams
        self._n = n
        # set the standard deviation parameter for gaussian penalty
        self._sigma = sigma

    def compute_score(self, gts, res):
        """
        Main function to compute CIDEr score
        :param  hypo_for_image (dict) : dictionary with key <image> and value <tokenized hypothesis / candidate sentence>
                ref_for_image (dict)  : dictionary with key <image> and value <tokenized reference sentence>
        :return: cider (float) : computed CIDEr score for the corpus
        """

        assert(gts.keys() == res.keys())
        imgIds = gts.keys()

        cider_scorer = CiderScorer(n=self._n, sigma=self._sigma)

        for id in imgIds:
            hypo = res[id]
            ref = gts[id]

            # Sanity check.
            assert(type(hypo) is list)
            assert(len(hypo) == 1)
            assert(type(ref) is list)
            assert(len(ref) > 0)

            cider_scorer += (hypo[0], ref)

        (score, scores) = cider_scorer.compute_score()

        return score, scores

    def method(self):
        return "CIDEr"
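A hedged usage sketch for the Cider wrapper above: it assumes already-tokenized captions keyed by image id, and the example sentences are invented.

from pycocoevalcap.cider.cider import Cider

gts = {"img1": ["the heart size is normal", "normal heart size"],
       "img2": ["the lungs are clear"]}
res = {"img1": ["the heart is normal in size"],
       "img2": ["the lungs are clear"]}
score, per_image = Cider().compute_score(gts, res)  # corpus CIDEr, per-image array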
pycocoevalcap/cider/cider_scorer.py
ADDED
@@ -0,0 +1,197 @@
#!/usr/bin/env python
# Tsung-Yi Lin <tl483@cornell.edu>
# Ramakrishna Vedantam <vrama91@vt.edu>

# Last modified : Wed 22 May 2019 08:10:00 PM EDT
# By Sabarish Sivanath
# To support Python 3

import copy
from collections import defaultdict
import numpy as np
import pdb
import math

def precook(s, n=4, out=False):
    """
    Takes a string as input and returns an object that can be given to
    either cook_refs or cook_test. This is optional: cook_refs and cook_test
    can take string arguments as well.
    :param s: string : sentence to be converted into ngrams
    :param n: int    : number of ngrams for which representation is calculated
    :return: term frequency vector for occurring ngrams
    """
    words = s.split()
    counts = defaultdict(int)
    for k in range(1,n+1):
        for i in range(len(words)-k+1):
            ngram = tuple(words[i:i+k])
            counts[ngram] += 1
    return counts

def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
    '''Takes a list of reference sentences for a single segment
    and returns an object that encapsulates everything that BLEU
    needs to know about them.
    :param refs: list of string : reference sentences for some image
    :param n: int : number of ngrams for which (ngram) representation is calculated
    :return: result (list of dict)
    '''
    return [precook(ref, n) for ref in refs]

def cook_test(test, n=4):
    '''Takes a test sentence and returns an object that
    encapsulates everything that BLEU needs to know about it.
    :param test: list of string : hypothesis sentence for some image
    :param n: int : number of ngrams for which (ngram) representation is calculated
    :return: result (dict)
    '''
    return precook(test, n, True)

class CiderScorer(object):
    """CIDEr scorer.
    """

    def copy(self):
        ''' copy the refs.'''
        new = CiderScorer(n=self.n)
        new.ctest = copy.copy(self.ctest)
        new.crefs = copy.copy(self.crefs)
        return new

    def __init__(self, test=None, refs=None, n=4, sigma=6.0):
        ''' singular instance '''
        self.n = n
        self.sigma = sigma
        self.crefs = []
        self.ctest = []
        self.document_frequency = defaultdict(float)
        self.cook_append(test, refs)
        self.ref_len = None

    def cook_append(self, test, refs):
        '''called by constructor and __iadd__ to avoid creating new instances.'''

        if refs is not None:
            self.crefs.append(cook_refs(refs))
            if test is not None:
                self.ctest.append(cook_test(test)) ## N.B.: -1
            else:
                self.ctest.append(None) # lens of crefs and ctest have to match

    def size(self):
        assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
        return len(self.crefs)

    def __iadd__(self, other):
        '''add an instance (e.g., from another sentence).'''

        if type(other) is tuple:
            ## avoid creating new CiderScorer instances
            self.cook_append(other[0], other[1])
        else:
            self.ctest.extend(other.ctest)
            self.crefs.extend(other.crefs)

        return self

    def compute_doc_freq(self):
        '''
        Compute term frequency for reference data.
        This will be used to compute idf (inverse document frequency) later.
        The term frequency is stored in the object.
        :return: None
        '''
        for refs in self.crefs:
            # refs, k ref captions of one image
            for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]):
                self.document_frequency[ngram] += 1
            # maxcounts[ngram] = max(maxcounts.get(ngram,0), count)

    def compute_cider(self):
        def counts2vec(cnts):
            """
            Function maps counts of ngram to vector of tfidf weights.
            The function returns vec, an array of dictionaries that store the mapping of n-gram and tf-idf weights.
            The n-th entry of the array denotes n-grams of length n.
            :param cnts:
            :return: vec (array of dict), norm (array of float), length (int)
            """
            vec = [defaultdict(float) for _ in range(self.n)]
            length = 0
            norm = [0.0 for _ in range(self.n)]
            for (ngram,term_freq) in cnts.items():
                # give word count 1 if it doesn't appear in reference corpus
                df = np.log(max(1.0, self.document_frequency[ngram]))
                # ngram index
                n = len(ngram)-1
                # tf (term_freq) * idf (precomputed idf) for n-grams
                vec[n][ngram] = float(term_freq)*(self.ref_len - df)
                # compute norm for the vector. the norm will be used for computing similarity
                norm[n] += pow(vec[n][ngram], 2)

                if n == 1:
                    length += term_freq
            norm = [np.sqrt(n) for n in norm]
            return vec, norm, length

        def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
            '''
            Compute the cosine similarity of two vectors.
            :param vec_hyp: array of dictionary for vector corresponding to hypothesis
            :param vec_ref: array of dictionary for vector corresponding to reference
            :param norm_hyp: array of float for vector corresponding to hypothesis
            :param norm_ref: array of float for vector corresponding to reference
            :param length_hyp: int containing length of hypothesis
            :param length_ref: int containing length of reference
            :return: array of score for each n-gram's cosine similarity
            '''
            delta = float(length_hyp - length_ref)
            # measure cosine similarity
            val = np.array([0.0 for _ in range(self.n)])
            for n in range(self.n):
                # ngram
                for (ngram,count) in vec_hyp[n].items():
                    # vrama91 : added clipping
                    val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]

                if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
                    val[n] /= (norm_hyp[n]*norm_ref[n])

                assert(not math.isnan(val[n]))
                # vrama91: added a length based gaussian penalty
                val[n] *= np.e**(-(delta**2)/(2*self.sigma**2))
            return val

        # compute log reference length
        self.ref_len = np.log(float(len(self.crefs)))

        scores = []
        for test, refs in zip(self.ctest, self.crefs):
            # compute vector for test captions
            vec, norm, length = counts2vec(test)
            # compute vector for ref captions
            score = np.array([0.0 for _ in range(self.n)])
            for ref in refs:
                vec_ref, norm_ref, length_ref = counts2vec(ref)
                score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
            # change by vrama91 - mean of ngram scores, instead of sum
            score_avg = np.mean(score)
            # divide by number of references
            score_avg /= len(refs)
            # multiply score by 10
            score_avg *= 10.0
            # append score of an image to the score list
            scores.append(score_avg)
        return scores

    def compute_score(self, option=None, verbose=0):
        # compute idf
        self.compute_doc_freq()
        # assert to check document frequency
        assert(len(self.ctest) >= max(self.document_frequency.values()))
        # compute cider score
        score = self.compute_cider()
        # debug
        # print score
        return np.mean(np.array(score)), np.array(score)
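For reference, a small sketch (invented example, not from the commit) of driving CiderScorer directly, which is what the Cider wrapper does internally; note that IDF is computed over the accumulated reference set before any per-image score is produced.

from pycocoevalcap.cider.cider_scorer import CiderScorer

cs = CiderScorer(n=4, sigma=6.0)
cs += ("the lungs are clear", ["the lungs are clear", "no focal consolidation"])
cs += ("small left pleural effusion", ["there is a small left pleural effusion"])
mean_cider, per_image = cs.compute_score()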
pycocoevalcap/eval.py
ADDED
@@ -0,0 +1,74 @@
__author__ = 'tylin'
from .tokenizer.ptbtokenizer import PTBTokenizer
from .bleu.bleu import Bleu
from .meteor.meteor import Meteor
from .rouge.rouge import Rouge
from .cider.cider import Cider

class COCOEvalCap:
    def __init__(self, coco, cocoRes):
        self.evalImgs = []
        self.eval = {}
        self.imgToEval = {}
        self.coco = coco
        self.cocoRes = cocoRes
        self.params = {'image_id': cocoRes.getImgIds()}

    def evaluate(self):
        imgIds = self.params['image_id']
        # imgIds = self.coco.getImgIds()
        gts = {}
        res = {}
        for imgId in imgIds:
            gts[imgId] = self.coco.imgToAnns[imgId]
            res[imgId] = self.cocoRes.imgToAnns[imgId]

        # =================================================
        # Tokenize hypotheses and references
        # =================================================
        print('tokenization...')
        tokenizer = PTBTokenizer()
        gts = tokenizer.tokenize(gts)
        res = tokenizer.tokenize(res)

        # =================================================
        # Set up scorers
        # =================================================
        print('setting up scorers...')
        scorers = [
            (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
            (Meteor(), "METEOR"),
            (Rouge(), "ROUGE_L"),
            (Cider(), "CIDEr")
        ]

        # =================================================
        # Compute scores
        # =================================================
        eval = {}
        for scorer, method in scorers:
            print('computing %s score...'%(scorer.method()))
            score, scores = scorer.compute_score(gts, res)
            if type(method) == list:
                for sc, scs, m in zip(score, scores, method):
                    self.setEval(sc, m)
                    self.setImgToEvalImgs(scs, imgIds, m)
                    print("%s: %0.3f"%(m, sc))
            else:
                self.setEval(score, method)
                self.setImgToEvalImgs(scores, imgIds, method)
                print("%s: %0.3f"%(method, score))
        self.setEvalImgs()

    def setEval(self, score, method):
        self.eval[method] = score

    def setImgToEvalImgs(self, scores, imgIds, method):
        for imgId, score in zip(imgIds, scores):
            if not imgId in self.imgToEval:
                self.imgToEval[imgId] = {}
                self.imgToEval[imgId]["image_id"] = imgId
            self.imgToEval[imgId][method] = score

    def setEvalImgs(self):
        self.evalImgs = [eval for imgId, eval in self.imgToEval.items()]
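A hedged end-to-end sketch for COCOEvalCap above. It assumes annotation and result files in COCO caption format and that pycocotools is installed; the file paths below are placeholders, not files from this commit.

from pycocotools.coco import COCO
from pycocoevalcap.eval import COCOEvalCap

coco = COCO("annotations/captions_val.json")               # placeholder path
coco_res = coco.loadRes("results/captions_results.json")   # placeholder path
evaluator = COCOEvalCap(coco, coco_res)
evaluator.evaluate()
print(evaluator.eval)  # e.g. {"Bleu_1": ..., "METEOR": ..., "ROUGE_L": ..., "CIDEr": ...}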
pycocoevalcap/license.txt
ADDED
@@ -0,0 +1,26 @@
Copyright (c) 2015, Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

The views and conclusions contained in the software and documentation are those
of the authors and should not be interpreted as representing official policies,
either expressed or implied, of the FreeBSD Project.
pycocoevalcap/meteor/__init__.py
ADDED
@@ -0,0 +1 @@
from .meteor import *
pycocoevalcap/meteor/meteor-1.5.jar
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1e57b4c72c0830ebe68558f1c799a624e96cbc1b6045c9f6330e26dcff6eafc2
size 6318693
pycocoevalcap/meteor/meteor.py
ADDED
@@ -0,0 +1,88 @@
#!/usr/bin/env python

# Python wrapper for METEOR implementation, by Xinlei Chen
# Acknowledge Michael Denkowski for the generous discussion and help

# Last modified : Wed 22 May 2019 08:10:00 PM EDT
# By Sabarish Sivanath
# To support Python 3

import os
import sys
import subprocess
import threading

# Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed.
METEOR_JAR = 'meteor-1.5.jar'


# print METEOR_JAR

class Meteor:

    def __init__(self):
        self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \
                '-', '-', '-stdio', '-l', 'en', '-norm']
        self.meteor_p = subprocess.Popen(self.meteor_cmd, \
                cwd=os.path.dirname(os.path.abspath(__file__)), \
                stdin=subprocess.PIPE, \
                stdout=subprocess.PIPE, \
                stderr=subprocess.PIPE,
                universal_newlines=True,
                bufsize=1)
        # Used to guarantee thread safety
        self.lock = threading.Lock()

    def compute_score(self, gts, res):
        assert (gts.keys() == res.keys())
        imgIds = gts.keys()
        scores = []

        eval_line = 'EVAL'
        self.lock.acquire()
        for i in imgIds:
            assert (len(res[i]) == 1)
            stat = self._stat(res[i][0], gts[i])
            eval_line += ' ||| {}'.format(stat)

        self.meteor_p.stdin.write('{}\n'.format(eval_line))
        for i in range(0, len(imgIds)):
            scores.append(float(self.meteor_p.stdout.readline().strip()))
        score = float(self.meteor_p.stdout.readline().strip())
        self.lock.release()

        return score, scores

    def method(self):
        return "METEOR"

    def _stat(self, hypothesis_str, reference_list):
        # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
        hypothesis_str = hypothesis_str.replace('|||', '').replace('  ', ' ')
        score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
        self.meteor_p.stdin.write('{}\n'.format(score_line))
        return self.meteor_p.stdout.readline().strip()

    def _score(self, hypothesis_str, reference_list):
        self.lock.acquire()
        # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
        hypothesis_str = hypothesis_str.replace('|||', '').replace('  ', ' ')
        score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
        self.meteor_p.stdin.write('{}\n'.format(score_line))
        stats = self.meteor_p.stdout.readline().strip()
        eval_line = 'EVAL ||| {}'.format(stats)
        # EVAL ||| stats
        self.meteor_p.stdin.write('{}\n'.format(eval_line))
        score = float(self.meteor_p.stdout.readline().strip())
        # bug fix: there are two values returned by the jar file, one average and one all, so read it twice
        # thanks to Andrej for pointing this out
        score = float(self.meteor_p.stdout.readline().strip())
        self.lock.release()
        return score

    def __del__(self):
        self.lock.acquire()
        self.meteor_p.stdin.close()
        self.meteor_p.kill()
        self.meteor_p.wait()
        self.lock.release()
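A hedged sketch for the Meteor wrapper above: it spawns meteor-1.5.jar, so a Java runtime must be available and the jar must sit next to meteor.py (as uploaded in this commit). The example captions are invented.

from pycocoevalcap.meteor.meteor import Meteor

gts = {"img1": ["the heart size is normal"]}
res = {"img1": ["the heart is normal in size"]}
score, scores = Meteor().compute_score(gts, res)  # corpus METEOR, per-image scores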
pycocoevalcap/rouge/__init__.py
ADDED
@@ -0,0 +1 @@
from .rouge import *
pycocoevalcap/rouge/rouge.py
ADDED
@@ -0,0 +1,105 @@
#!/usr/bin/env python
#
# File Name : rouge.py
#
# Description : Computes ROUGE-L metric as described by Lin and Hovy (2004)
#
# Creation Date : 2015-01-07 06:03
# Author : Ramakrishna Vedantam <vrama91@vt.edu>

import numpy as np
import pdb

def my_lcs(string, sub):
    """
    Calculates longest common subsequence for a pair of tokenized strings
    :param string : list of str : tokens from a string split using whitespace
    :param sub : list of str : shorter string, also split using whitespace
    :returns: length (int): length of the longest common subsequence between the two strings

    Note: my_lcs only gives length of the longest common subsequence, not the actual LCS
    """
    if(len(string)< len(sub)):
        sub, string = string, sub

    lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)]

    for j in range(1,len(sub)+1):
        for i in range(1,len(string)+1):
            if(string[i-1] == sub[j-1]):
                lengths[i][j] = lengths[i-1][j-1] + 1
            else:
                lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1])

    return lengths[len(string)][len(sub)]

class Rouge():
    '''
    Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set
    '''
    def __init__(self):
        # vrama91: updated the value below based on discussion with Hovy
        self.beta = 1.2

    def calc_score(self, candidate, refs):
        """
        Compute ROUGE-L score given one candidate and references for an image
        :param candidate: str : candidate sentence to be evaluated
        :param refs: list of str : COCO reference sentences for the particular image to be evaluated
        :returns score: float (ROUGE-L score for the candidate evaluated against references)
        """
        assert(len(candidate)==1)
        assert(len(refs)>0)
        prec = []
        rec = []

        # split into tokens
        token_c = candidate[0].split(" ")

        for reference in refs:
            # split into tokens
            token_r = reference.split(" ")
            # compute the longest common subsequence
            lcs = my_lcs(token_r, token_c)
            prec.append(lcs/float(len(token_c)))
            rec.append(lcs/float(len(token_r)))

        prec_max = max(prec)
        rec_max = max(rec)

        if(prec_max!=0 and rec_max !=0):
            score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max)
        else:
            score = 0.0
        return score

    def compute_score(self, gts, res):
        """
        Computes ROUGE-L score given a set of reference and candidate sentences for the dataset
        Invoked by evaluate_captions.py
        :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values
        :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values
        :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images)
        """
        assert(gts.keys() == res.keys())
        imgIds = gts.keys()

        score = []
        for id in imgIds:
            hypo = res[id]
            ref = gts[id]

            score.append(self.calc_score(hypo, ref))

            # Sanity check.
            assert(type(hypo) is list)
            assert(len(hypo) == 1)
            assert(type(ref) is list)
            assert(len(ref) > 0)

        average_score = np.mean(np.array(score))
        return average_score, np.array(score)

    def method(self):
        return "Rouge"
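A minimal sketch for the Rouge class above (invented sentences, not from the commit); calc_score expects the candidate wrapped in a one-element list.

from pycocoevalcap.rouge.rouge import Rouge

rouge_l = Rouge().calc_score(["the lungs are clear"],
                             ["the lungs are grossly clear", "no consolidation"])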
pycocoevalcap/tokenizer/__init__.py
ADDED
@@ -0,0 +1 @@
__author__ = 'hfang'
pycocoevalcap/tokenizer/ptbtokenizer.py
ADDED
@@ -0,0 +1,76 @@
#!/usr/bin/env python
#
# File Name : ptbtokenizer.py
#
# Description : Do the PTB Tokenization and remove punctuations.
#
# Creation Date : 29-12-2014
# Last Modified : Thu Mar 19 09:53:35 2015
# Authors : Hao Fang <hfang@uw.edu> and Tsung-Yi Lin <tl483@cornell.edu>

import os
import sys
import subprocess
import tempfile
import itertools


# Last modified : Wed 22 May 2019 08:10:00 PM EDT
# By Sabarish Sivanath
# To support Python 3

# path to the stanford corenlp jar
STANFORD_CORENLP_3_4_1_JAR = 'stanford-corenlp-3.4.1.jar'

# punctuations to be removed from the sentences
PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \
        ".", "?", "!", ",", ":", "-", "--", "...", ";"]

class PTBTokenizer:
    """Python wrapper of Stanford PTBTokenizer"""

    def tokenize(self, captions_for_image):
        cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR, \
                'edu.stanford.nlp.process.PTBTokenizer', \
                '-preserveLines', '-lowerCase']

        # ======================================================
        # prepare data for PTB Tokenizer
        # ======================================================
        final_tokenized_captions_for_image = {}
        image_id = [k for k, v in captions_for_image.items() for _ in range(len(v))]
        sentences = '\n'.join([c['caption'].replace('\n', ' ') for k, v in captions_for_image.items() for c in v])

        # ======================================================
        # save sentences to temporary file
        # ======================================================
        path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__))
        tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname)
        tmp_file.write(sentences.encode('utf-8'))
        tmp_file.close()

        # ======================================================
        # tokenize sentence
        # ======================================================
        cmd.append(os.path.basename(tmp_file.name))
        p_tokenizer = subprocess.Popen(cmd,
                                       cwd=path_to_jar_dirname,
                                       stdout=subprocess.PIPE,
                                       universal_newlines = True,
                                       bufsize = 1)
        token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0]
        lines = token_lines.split('\n')
        # remove temp file
        os.remove(tmp_file.name)

        # ======================================================
        # create dictionary for tokenized captions
        # ======================================================
        for k, line in zip(image_id, lines):
            if not k in final_tokenized_captions_for_image:
                final_tokenized_captions_for_image[k] = []
            tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \
                    if w not in PUNCTUATIONS])
            final_tokenized_captions_for_image[k].append(tokenized_caption)

        return final_tokenized_captions_for_image
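A hedged sketch for PTBTokenizer above: it shells out to stanford-corenlp-3.4.1.jar (uploaded alongside this file), so Java must be on PATH. The caption dict shape matches what COCOEvalCap passes in; the example caption is invented.

from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer

captions = {"img1": [{"caption": "The heart size is normal."}]}
tokenized = PTBTokenizer().tokenize(captions)
# -> {"img1": ["the heart size is normal"]}  (lowercased, punctuation stripped)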