Antonio Cheong commited on
Commit
4d7378e
1 Parent(s): 4f15858
Files changed (14) hide show
  1. CODE_OF_CONDUCT.md +4 -0
  2. CONTRIBUTING.md +59 -0
  3. LICENSE +175 -0
  4. NOTICE +1 -0
  5. evaluations.py +100 -0
  6. main.py +383 -0
  7. model.py +194 -0
  8. requirements.txt +11 -0
  9. run_inference.sh +17 -0
  10. run_training.sh +15 -0
  11. utils_data.py +228 -0
  12. utils_evaluate.py +108 -0
  13. utils_prompt.py +240 -0
  14. vision_features/mm-cot.png +0 -0
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,4 @@
 
 
 
 
1
+ ## Code of Conduct
2
+ This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
3
+ For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
4
+ opensource-codeofconduct@amazon.com with any additional questions or comments.
CONTRIBUTING.md ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributing Guidelines
2
+
3
+ Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
4
+ documentation, we greatly value feedback and contributions from our community.
5
+
6
+ Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
7
+ information to effectively respond to your bug report or contribution.
8
+
9
+
10
+ ## Reporting Bugs/Feature Requests
11
+
12
+ We welcome you to use the GitHub issue tracker to report bugs or suggest features.
13
+
14
+ When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
15
+ reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
16
+
17
+ * A reproducible test case or series of steps
18
+ * The version of our code being used
19
+ * Any modifications you've made relevant to the bug
20
+ * Anything unusual about your environment or deployment
21
+
22
+
23
+ ## Contributing via Pull Requests
24
+ Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
25
+
26
+ 1. You are working against the latest source on the *main* branch.
27
+ 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
28
+ 3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
29
+
30
+ To send us a pull request, please:
31
+
32
+ 1. Fork the repository.
33
+ 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
34
+ 3. Ensure local tests pass.
35
+ 4. Commit to your fork using clear commit messages.
36
+ 5. Send us a pull request, answering any default questions in the pull request interface.
37
+ 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
38
+
39
+ GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
40
+ [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
41
+
42
+
43
+ ## Finding contributions to work on
44
+ Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.
45
+
46
+
47
+ ## Code of Conduct
48
+ This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
49
+ For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
50
+ opensource-codeofconduct@amazon.com with any additional questions or comments.
51
+
52
+
53
+ ## Security issue notifications
54
+ If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue.
55
+
56
+
57
+ ## Licensing
58
+
59
+ See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
LICENSE ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
NOTICE ADDED
@@ -0,0 +1 @@
 
1
+ Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
evaluations.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Adapted from https://github.com/lupantech/ScienceQA
3
+ '''
4
+
5
+ import re
6
+ from rouge import Rouge
7
+ from nltk.translate.bleu_score import sentence_bleu
8
+ from sentence_transformers import util
9
+
10
+ ########################
11
+ ## BLEU
12
+ ########################
13
+ def tokenize(text):
14
+ tokens = re.split(r'\s|\.', text)
15
+ tokens = [t for t in tokens if len(t) > 0]
16
+ return tokens
17
+
18
+
19
+ def bleu_score(reference, hypothesis, gram):
20
+ reference_tokens = tokenize(reference)
21
+ hypothesis_tokens = tokenize(hypothesis)
22
+
23
+ if gram == 1:
24
+ bleu = sentence_bleu([reference_tokens], hypothesis_tokens, (1., )) # BELU-1
25
+ elif gram == 2:
26
+ bleu = sentence_bleu([reference_tokens], hypothesis_tokens, (1. / 2., 1. / 2.)) # BELU-2
27
+ elif gram == 3:
28
+ bleu = sentence_bleu([reference_tokens], hypothesis_tokens, (1. / 3., 1. / 3., 1. / 3.)) # BELU-3
29
+ elif gram == 4:
30
+ bleu = sentence_bleu([reference_tokens], hypothesis_tokens, (1. / 4., 1. / 4., 1. / 4., 1. / 4.)) # BELU-4
31
+
32
+ return bleu
33
+
34
+
35
+ def caculate_bleu(results, data, gram):
36
+ bleus = []
37
+ for qid, output in results.items():
38
+ prediction = output
39
+ target = data[qid]
40
+ target = target.strip()
41
+ if target == "":
42
+ continue
43
+ bleu = bleu_score(target, prediction, gram)
44
+ bleus.append(bleu)
45
+
46
+ avg_bleu = sum(bleus) / len(bleus)
47
+
48
+ return avg_bleu
49
+
50
+
51
+ ########################
52
+ ## Rouge-L
53
+ ########################
54
+ def score_rouge(str1, str2):
55
+ rouge = Rouge(metrics=["rouge-l"])
56
+ scores = rouge.get_scores(str1, str2, avg=True)
57
+ rouge_l = scores['rouge-l']['f']
58
+ return rouge_l
59
+
60
+
61
+ def caculate_rouge(results, data):
62
+ rouges = []
63
+ for qid, output in results.items():
64
+ prediction = output
65
+ target = data[qid]
66
+ target = target.strip()
67
+ if prediction == "":
68
+ continue
69
+ if target == "":
70
+ continue
71
+ rouge = score_rouge(target, prediction)
72
+ rouges.append(rouge)
73
+
74
+ avg_rouge = sum(rouges) / len(rouges)
75
+ return avg_rouge
76
+
77
+
78
+ ########################
79
+ ## Sentence Similarity
80
+ ########################
81
+ def similariry_score(str1, str2, model):
82
+ # compute embedding for both lists
83
+ embedding_1 = model.encode(str1, convert_to_tensor=True)
84
+ embedding_2 = model.encode(str2, convert_to_tensor=True)
85
+ score = util.pytorch_cos_sim(embedding_1, embedding_2).item()
86
+ return score
87
+
88
+
89
+ def caculate_similariry(results, data, model):
90
+ scores = []
91
+ for qid, output in results.items():
92
+ prediction = output
93
+ target = data[qid]
94
+ target = target.strip()
95
+
96
+ score = similariry_score(target, prediction, model)
97
+ scores.append(score)
98
+
99
+ avg_score = sum(scores) / len(scores)
100
+ return avg_score
main.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import os
5
+ import re
6
+ import json
7
+ import argparse
8
+ import random
9
+ from transformers import T5Tokenizer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, T5ForConditionalGeneration
10
+ from model import T5ForConditionalGeneration, T5ForMultimodalGeneration
11
+ from utils_data import img_shape, load_data_std, load_data_img, ScienceQADatasetStd, ScienceQADatasetImg
12
+ from utils_prompt import *
13
+ from utils_evaluate import get_scores
14
+ from rich.table import Column, Table
15
+ from rich import box
16
+ from rich.console import Console
17
+ console = Console(record=True)
18
+ from torch import cuda
19
+ import nltk
20
+ import evaluate
21
+
22
+
23
+ def parse_args():
24
+ parser = argparse.ArgumentParser()
25
+ parser.add_argument('--data_root', type=str, default='data')
26
+ parser.add_argument('--output_dir', type=str, default='experiments')
27
+ parser.add_argument('--model', type=str, default='allenai/unifiedqa-t5-base')
28
+ parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
29
+ parser.add_argument('--epoch', type=int, default=20)
30
+ parser.add_argument('--lr', type=float, default=5e-5)
31
+ parser.add_argument('--bs', type=int, default=16)
32
+ parser.add_argument('--input_len', type=int, default=512)
33
+ parser.add_argument('--output_len', type=int, default=64)
34
+ parser.add_argument('--eval_bs', type=int, default=16)
35
+ parser.add_argument('--eval_acc', type=int, default=None, help='evaluate accumulation step')
36
+ parser.add_argument('--train_split', type=str, default='train', choices=['train', 'trainval', 'minitrain'])
37
+ parser.add_argument('--val_split', type=str, default='val', choices=['test', 'val', 'minival'])
38
+ parser.add_argument('--test_split', type=str, default='test', choices=['test', 'minitest'])
39
+
40
+ parser.add_argument('--use_generate', action='store_true', help='only for baseline to improve inference speed')
41
+ parser.add_argument('--final_eval', action='store_true', help='only evaluate the model at the final epoch')
42
+ parser.add_argument('--user_msg', type=str, default="baseline", help='experiment type in the save_dir')
43
+ parser.add_argument('--img_type', type=str, default=None, choices=['detr', 'clip', 'resnet'], help='type of image features')
44
+ parser.add_argument('--eval_le', type=str, default=None, help='generated rationale for the dev set')
45
+ parser.add_argument('--test_le', type=str, default=None, help='generated rationale for the test set')
46
+ parser.add_argument('--evaluate_dir', type=str, default=None, help='the directory of model for evaluation')
47
+ parser.add_argument('--caption_file', type=str, default='data/captions.json')
48
+ parser.add_argument('--use_caption', action='store_true', help='use image captions or not')
49
+ parser.add_argument('--prompt_format', type=str, default='QCM-A', help='prompt format template',
50
+ choices=['QCM-A', 'QCM-LE', 'QCMG-A', 'QCM-LEA', 'QCM-ALE'])
51
+ parser.add_argument('--seed', type=int, default=42, help='random seed')
52
+
53
+ args = parser.parse_args()
54
+ return args
55
+
56
+ def T5Trainer(
57
+ dataframe, args,
58
+ ):
59
+ torch.manual_seed(args.seed) # pytorch random seed
60
+ np.random.seed(args.seed) # numpy random seed
61
+ torch.backends.cudnn.deterministic = True
62
+
63
+ if args.evaluate_dir is not None:
64
+ args.model = args.evaluate_dir
65
+
66
+ tokenizer = T5Tokenizer.from_pretrained(args.model)
67
+
68
+ console.log(f"""[Model]: Loading {args.model}...\n""")
69
+ console.log(f"[Data]: Reading data...\n")
70
+ problems = dataframe['problems']
71
+ qids = dataframe['qids']
72
+ train_qids = qids['train']
73
+ test_qids = qids['test']
74
+ val_qids = qids['val']
75
+
76
+ if args.evaluate_dir is not None:
77
+ save_dir = args.evaluate_dir
78
+ else:
79
+ model_name = args.model.replace("/","-")
80
+ gpu_count = torch.cuda.device_count()
81
+ save_dir = f"{args.output_dir}/{args.user_msg}_{model_name}_{args.img_type}_{args.prompt_format}_lr{args.lr}_bs{args.bs * gpu_count}_op{args.output_len}_ep{args.epoch}"
82
+ if not os.path.exists(save_dir):
83
+ os.mkdir(save_dir)
84
+
85
+ padding_idx = tokenizer._convert_token_to_id(tokenizer.pad_token)
86
+ if args.img_type is not None:
87
+ patch_size = img_shape[args.img_type]
88
+ model = T5ForMultimodalGeneration.from_pretrained(args.model, patch_size=patch_size, padding_idx=padding_idx, save_dir=save_dir)
89
+ name_maps = dataframe['name_maps']
90
+ image_features = dataframe['image_features']
91
+ train_set = ScienceQADatasetImg(
92
+ problems,
93
+ train_qids,
94
+ name_maps,
95
+ tokenizer,
96
+ args.input_len,
97
+ args.output_len,
98
+ args,
99
+ image_features,
100
+ )
101
+ eval_set = ScienceQADatasetImg(
102
+ problems,
103
+ val_qids,
104
+ name_maps,
105
+ tokenizer,
106
+ args.input_len,
107
+ args.output_len,
108
+ args,
109
+ image_features,
110
+ args.eval_le,
111
+ )
112
+ test_set = ScienceQADatasetImg(
113
+ problems,
114
+ test_qids,
115
+ name_maps,
116
+ tokenizer,
117
+ args.input_len,
118
+ args.output_len,
119
+ args,
120
+ image_features,
121
+ args.test_le,
122
+ )
123
+ else:
124
+ model = T5ForConditionalGeneration.from_pretrained(args.model)
125
+ train_set = ScienceQADatasetStd(
126
+ problems,
127
+ train_qids,
128
+ tokenizer,
129
+ args.input_len,
130
+ args.output_len,
131
+ args,
132
+ )
133
+ eval_set = ScienceQADatasetStd(
134
+ problems,
135
+ val_qids,
136
+ tokenizer,
137
+ args.input_len,
138
+ args.output_len,
139
+ args,
140
+ args.eval_le,
141
+ )
142
+
143
+ test_set = ScienceQADatasetStd(
144
+ problems,
145
+ test_qids,
146
+ tokenizer,
147
+ args.input_len,
148
+ args.output_len,
149
+ args,
150
+ args.test_le,
151
+ )
152
+
153
+ datacollator = DataCollatorForSeq2Seq(tokenizer)
154
+ print("model parameters: ", model.num_parameters())
155
+ def extract_ans(ans):
156
+ pattern = re.compile(r'The answer is \(([A-Z])\)')
157
+ res = pattern.findall(ans)
158
+
159
+ if len(res) == 1:
160
+ answer = res[0] # 'A', 'B', ...
161
+ else:
162
+ answer = "FAILED"
163
+ return answer
164
+
165
+ # accuracy for answer inference
166
+ def compute_metrics_acc(eval_preds):
167
+ if args.use_generate:
168
+ preds, targets = eval_preds
169
+ if isinstance(preds, tuple):
170
+ preds = preds[0]
171
+ else:
172
+ preds = eval_preds.predictions[0]
173
+ targets = eval_preds.label_ids
174
+ preds = preds.argmax(axis=2)
175
+ preds = tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
176
+ targets = tokenizer.batch_decode(targets, skip_special_tokens=True, clean_up_tokenization_spaces=True)
177
+ correct = 0
178
+ assert len(preds) == len(targets)
179
+ for idx, pred in enumerate(preds):
180
+ reference = targets[idx]
181
+ reference = extract_ans(reference)
182
+ extract_pred = extract_ans(pred)
183
+ best_option = extract_pred
184
+ if reference == best_option:
185
+ correct +=1
186
+ return {'accuracy': 1.0*correct/len(targets)}
187
+
188
+ # rougel for rationale generation
189
+ metric = evaluate.load("rouge")
190
+ def postprocess_text(preds, labels):
191
+ preds = [pred.strip() for pred in preds]
192
+ labels = [label.strip() for label in labels]
193
+ preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
194
+ labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
195
+ return preds, labels
196
+
197
+ def compute_metrics_rougel(eval_preds):
198
+ if args.use_generate:
199
+ preds, targets = eval_preds
200
+ if isinstance(preds, tuple):
201
+ preds = preds[0]
202
+ else:
203
+ preds = eval_preds.predictions[0]
204
+ targets = eval_preds.label_ids
205
+ preds = preds.argmax(axis=2)
206
+ preds = tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
207
+ targets = tokenizer.batch_decode(targets, skip_special_tokens=True, clean_up_tokenization_spaces=True)
208
+
209
+ decoded_preds, decoded_labels = postprocess_text(preds, targets)
210
+
211
+ result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
212
+ result = {k: round(v * 100, 4) for k, v in result.items()}
213
+ prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
214
+ result["gen_len"] = np.mean(prediction_lens)
215
+ return result
216
+
217
+ # only use the last model for evaluation to save time
218
+ if args.final_eval:
219
+ training_args = Seq2SeqTrainingArguments(
220
+ save_dir,
221
+ do_train=True if args.evaluate_dir is None else False,
222
+ do_eval=False,
223
+ evaluation_strategy="no",
224
+ logging_strategy="steps",
225
+ save_strategy="epoch",
226
+ save_total_limit = 2,
227
+ learning_rate= args.lr,
228
+ eval_accumulation_steps=args.eval_acc,
229
+ per_device_train_batch_size=args.bs,
230
+ per_device_eval_batch_size=args.eval_bs,
231
+ weight_decay=0.01,
232
+ num_train_epochs=args.epoch,
233
+ predict_with_generate=args.use_generate,
234
+ report_to="none",
235
+ )
236
+ # evaluate at each epoch
237
+ else:
238
+ training_args = Seq2SeqTrainingArguments(
239
+ save_dir,
240
+ do_train=True if args.evaluate_dir is None else False,
241
+ do_eval=True,
242
+ evaluation_strategy="epoch",
243
+ logging_strategy="steps",
244
+ save_strategy="epoch",
245
+ save_total_limit = 2,
246
+ learning_rate= args.lr,
247
+ eval_accumulation_steps=args.eval_acc,
248
+ per_device_train_batch_size=args.bs,
249
+ per_device_eval_batch_size=args.eval_bs,
250
+ weight_decay=0.01,
251
+ num_train_epochs=args.epoch,
252
+ metric_for_best_model="accuracy" if args.prompt_format != "QCM-LE" else "rougeL",
253
+ predict_with_generate=args.use_generate,
254
+ load_best_model_at_end=True,
255
+ report_to="none",
256
+ )
257
+
258
+ trainer = Seq2SeqTrainer(
259
+ model=model,
260
+ args=training_args,
261
+ train_dataset=train_set,
262
+ eval_dataset=eval_set,
263
+ data_collator=datacollator,
264
+ tokenizer=tokenizer,
265
+ compute_metrics = compute_metrics_acc if args.prompt_format != "QCM-LE" else compute_metrics_rougel
266
+ )
267
+
268
+ if args.evaluate_dir is None:
269
+ trainer.train()
270
+ trainer.save_model(save_dir)
271
+
272
+ metrics = trainer.evaluate(eval_dataset = test_set)
273
+ trainer.log_metrics("test", metrics)
274
+ trainer.save_metrics("test", metrics)
275
+
276
+ predict_results = trainer.predict(test_dataset=test_set, max_length=args.output_len)
277
+ if trainer.is_world_process_zero():
278
+ if args.use_generate:
279
+ preds, targets = predict_results.predictions, predict_results.label_ids
280
+ else:
281
+ preds = predict_results.predictions[0]
282
+ targets = predict_results.label_ids
283
+ preds = preds.argmax(axis=2)
284
+
285
+ preds = tokenizer.batch_decode(
286
+ preds, skip_special_tokens=True, clean_up_tokenization_spaces=True
287
+ )
288
+ targets = tokenizer.batch_decode(
289
+ targets, skip_special_tokens=True, clean_up_tokenization_spaces=True
290
+ )
291
+
292
+ results_ans = {}
293
+ results_rationale = {}
294
+ results_reference = {}
295
+
296
+ num_fail = 0
297
+ for idx, qid in enumerate(test_qids):
298
+ pred = preds[int(idx)]
299
+ ref = targets[int(idx)]
300
+ extract_pred = extract_ans(pred)
301
+ if extract_pred != "FAILED":
302
+ if extract_pred in args.options:
303
+ extract_pred = args.options.index(extract_pred)
304
+ else:
305
+ extract_pred = random.choice(range(0,len(args.options)))
306
+ else:
307
+ num_fail += 1
308
+ extract_pred = random.choice(range(len(args.options))) # random choose one option
309
+ results_ans[str(qid)] = extract_pred
310
+ results_rationale[str(qid)] = pred
311
+ results_reference[str(qid)] = ref
312
+
313
+ scores = get_scores(results_ans, results_rationale, results_reference, os.path.join(args.data_root, "scienceqa/problems.json"))
314
+ preds = [pred.strip() for pred in preds]
315
+ output_data = {
316
+ "num_fail": num_fail,
317
+ "scores": scores,
318
+ "preds": preds,
319
+ "labels": targets}
320
+ output_prediction_file = os.path.join(save_dir,"predictions_ans_test.json")
321
+ with open(output_prediction_file, "w") as writer:
322
+ writer.write(json.dumps(output_data, indent=4))
323
+
324
+ # generate the rationale for the eval set
325
+ if args.prompt_format == "QCM-LE":
326
+ torch.cuda.empty_cache()
327
+ del predict_results, preds, targets
328
+ predict_results = trainer.predict(test_dataset=eval_set, max_length=args.output_len)
329
+ if trainer.is_world_process_zero():
330
+ if args.use_generate:
331
+ preds, targets = predict_results.predictions, predict_results.label_ids
332
+ else:
333
+ preds = predict_results.predictions[0]
334
+ targets = predict_results.label_ids
335
+ preds = preds.argmax(axis=2)
336
+
337
+ preds = tokenizer.batch_decode(
338
+ preds, skip_special_tokens=True, clean_up_tokenization_spaces=True
339
+ )
340
+ targets = tokenizer.batch_decode(
341
+ targets, skip_special_tokens=True, clean_up_tokenization_spaces=True
342
+ )
343
+ preds = [pred.strip() for pred in preds]
344
+ output_data = {"preds": preds,
345
+ "labels": targets}
346
+ output_prediction_file = os.path.join(save_dir,"predictions_ans_eval.json")
347
+ with open(output_prediction_file, "w") as writer:
348
+ writer.write(json.dumps(output_data, indent=4))
349
+
350
+
351
+ if __name__ == '__main__':
352
+
353
+ # training logger to log training progress
354
+ training_logger = Table(
355
+ Column("Epoch", justify="center"),
356
+ Column("Steps", justify="center"),
357
+ Column("Loss", justify="center"),
358
+ title="Training Status",
359
+ pad_edge=False,
360
+ box=box.ASCII,
361
+ )
362
+
363
+ args = parse_args()
364
+ print("args",args)
365
+ print('====Input Arguments====')
366
+ print(json.dumps(vars(args), indent=2, sort_keys=False))
367
+
368
+ random.seed(args.seed)
369
+
370
+ if not os.path.exists(args.output_dir):
371
+ os.mkdir(args.output_dir)
372
+
373
+ if args.img_type is not None:
374
+ problems, qids, name_maps, image_features = load_data_img(args) # probelms, test question ids, shot example ids
375
+ dataframe = {'problems':problems, 'qids':qids, 'name_maps': name_maps, 'image_features': image_features}
376
+ else:
377
+ problems, qids = load_data_std(args) # probelms, test question ids, shot example ids
378
+ dataframe = {'problems':problems, 'qids':qids}
379
+
380
+ T5Trainer(
381
+ dataframe=dataframe,
382
+ args = args
383
+ )
model.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Adapted from https://github.com/huggingface/transformers
3
+ '''
4
+
5
+ from transformers import T5Config, T5ForConditionalGeneration
6
+ from transformers.models.t5.modeling_t5 import T5Stack, __HEAD_MASK_WARNING_MSG, T5EncoderModel
7
+ import copy
8
+ import math
9
+ import os
10
+ import warnings
11
+ from typing import Optional, Tuple, Union
12
+ import torch
13
+ from torch import nn
14
+ from torch.nn import CrossEntropyLoss
15
+ from transformers.modeling_outputs import (
16
+ BaseModelOutput,
17
+ Seq2SeqLMOutput,
18
+ )
19
+
20
+ class T5ForMultimodalGeneration(T5ForConditionalGeneration):
21
+ _keys_to_ignore_on_load_missing = [
22
+ r"encoder.embed_tokens.weight",
23
+ r"decoder.embed_tokens.weight",
24
+ r"lm_head.weight",
25
+ ]
26
+ _keys_to_ignore_on_load_unexpected = [
27
+ r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
28
+ ]
29
+
30
+ def __init__(self, config: T5Config, patch_size, padding_idx, save_dir):
31
+ super().__init__(config)
32
+ self.model_dim = config.d_model
33
+
34
+ self.padding_idx = padding_idx
35
+ self.out = open(os.path.join(save_dir, 'gate.txt'), 'w')
36
+
37
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
38
+ self.patch_num, self.patch_dim = patch_size
39
+
40
+ self.image_dense = nn.Linear(self.patch_dim, config.d_model)
41
+ self.mha_layer = torch.nn.MultiheadAttention(embed_dim=config.hidden_size, kdim=config.hidden_size, vdim=config.hidden_size, num_heads=1, batch_first=True)
42
+ self.gate_dense = nn.Linear(2*config.hidden_size, config.hidden_size)
43
+ self.sigmoid = nn.Sigmoid()
44
+
45
+ encoder_config = copy.deepcopy(config)
46
+ encoder_config.is_decoder = False
47
+ encoder_config.use_cache = False
48
+ encoder_config.is_encoder_decoder = False
49
+ self.encoder = T5Stack(encoder_config, self.shared)
50
+
51
+ decoder_config = copy.deepcopy(config)
52
+ decoder_config.is_decoder = True
53
+ decoder_config.is_encoder_decoder = False
54
+ decoder_config.num_layers = config.num_decoder_layers
55
+ self.decoder = T5Stack(decoder_config, self.shared)
56
+
57
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
58
+
59
+ # Initialize weights and apply final processing
60
+ self.post_init()
61
+
62
+ # Model parallel
63
+ self.model_parallel = False
64
+ self.device_map = None
65
+
66
+ def forward(
67
+ self,
68
+ input_ids: Optional[torch.LongTensor] = None,
69
+ image_ids=None,
70
+ attention_mask: Optional[torch.FloatTensor] = None,
71
+ decoder_input_ids: Optional[torch.LongTensor] = None,
72
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
73
+ head_mask: Optional[torch.FloatTensor] = None,
74
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
75
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
76
+ encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
77
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
78
+ inputs_embeds: Optional[torch.FloatTensor] = None,
79
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
80
+ labels: Optional[torch.LongTensor] = None,
81
+ use_cache: Optional[bool] = None,
82
+ output_attentions: Optional[bool] = None,
83
+ output_hidden_states: Optional[bool] = None,
84
+ return_dict: Optional[bool] = None,
85
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
86
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
87
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
88
+
89
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
90
+ if head_mask is not None and decoder_head_mask is None:
91
+ if self.config.num_layers == self.config.num_decoder_layers:
92
+ warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
93
+ decoder_head_mask = head_mask
94
+
95
+ # Encode if needed (training, first prediction pass)
96
+ if encoder_outputs is None:
97
+ # Convert encoder inputs in embeddings if needed
98
+ encoder_outputs = self.encoder(
99
+ input_ids=input_ids,
100
+ attention_mask=attention_mask,
101
+ inputs_embeds=inputs_embeds,
102
+ head_mask=head_mask,
103
+ output_attentions=output_attentions,
104
+ output_hidden_states=output_hidden_states,
105
+ return_dict=return_dict,
106
+ )
107
+
108
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
109
+ encoder_outputs = BaseModelOutput(
110
+ last_hidden_state=encoder_outputs[0],
111
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
112
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
113
+ )
114
+
115
+
116
+ hidden_states = encoder_outputs[0]
117
+
118
+ image_embedding = self.image_dense(image_ids)
119
+ image_att, _ = self.mha_layer(hidden_states, image_embedding, image_embedding)
120
+
121
+ merge = torch.cat([hidden_states, image_att], dim=-1)
122
+ gate = self.sigmoid(self.gate_dense(merge))
123
+ hidden_states = (1 - gate) * hidden_states + gate * image_att
124
+
125
+ if self.model_parallel:
126
+ torch.cuda.set_device(self.decoder.first_device)
127
+
128
+ if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
129
+ # get decoder inputs from shifting lm labels to the right
130
+ decoder_input_ids = self._shift_right(labels)
131
+
132
+ # Set device for model parallelism
133
+ if self.model_parallel:
134
+ torch.cuda.set_device(self.decoder.first_device)
135
+ hidden_states = hidden_states.to(self.decoder.first_device)
136
+ if decoder_input_ids is not None:
137
+ decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
138
+ if attention_mask is not None:
139
+ attention_mask = attention_mask.to(self.decoder.first_device)
140
+ if decoder_attention_mask is not None:
141
+ decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
142
+
143
+ # Decode
144
+ decoder_outputs = self.decoder(
145
+ input_ids=decoder_input_ids,
146
+ attention_mask=decoder_attention_mask,
147
+ inputs_embeds=decoder_inputs_embeds,
148
+ past_key_values=past_key_values,
149
+ encoder_hidden_states=hidden_states,
150
+ encoder_attention_mask=attention_mask,
151
+ head_mask=decoder_head_mask,
152
+ cross_attn_head_mask=cross_attn_head_mask,
153
+ use_cache=use_cache,
154
+ output_attentions=output_attentions,
155
+ output_hidden_states=output_hidden_states,
156
+ return_dict=return_dict,
157
+ )
158
+
159
+ sequence_output = decoder_outputs[0]
160
+
161
+ # Set device for model parallelism
162
+ if self.model_parallel:
163
+ torch.cuda.set_device(self.encoder.first_device)
164
+ self.lm_head = self.lm_head.to(self.encoder.first_device)
165
+ sequence_output = sequence_output.to(self.lm_head.weight.device)
166
+
167
+ if self.config.tie_word_embeddings:
168
+ # Rescale output before projecting on vocab
169
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
170
+ sequence_output = sequence_output * (self.model_dim**-0.5)
171
+
172
+ lm_logits = self.lm_head(sequence_output)
173
+
174
+ loss = None
175
+ if labels is not None:
176
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
177
+ loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
178
+ # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
179
+
180
+ if not return_dict:
181
+ output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
182
+ return ((loss,) + output) if loss is not None else output
183
+
184
+ return Seq2SeqLMOutput(
185
+ loss=loss,
186
+ logits=lm_logits,
187
+ past_key_values=decoder_outputs.past_key_values,
188
+ decoder_hidden_states=decoder_outputs.hidden_states,
189
+ decoder_attentions=decoder_outputs.attentions,
190
+ cross_attentions=decoder_outputs.cross_attentions,
191
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
192
+ encoder_hidden_states=encoder_outputs.hidden_states,
193
+ encoder_attentions=encoder_outputs.attentions,
194
+ )
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ huggingface-hub==0.0.12
2
+ numpy==1.23.2
3
+ openai==0.23.0
4
+ pandas==1.4.3
5
+ rouge==1.0.1
6
+ sentence-transformers==2.2.2
7
+ transformers==4.21.1
8
+ nltk==3.6.6
9
+ evaluate==0.4.0
10
+ rouge==1.0.1
11
+ rouge_score==0.1.2
run_inference.sh ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # rationale generation
2
+ CUDA_VISIBLE_DEVICES=0,1 python main.py \
3
+ --model allenai/unifiedqa-t5-base \
4
+ --user_msg rationale --img_type detr \
5
+ --bs 8 --eval_bs 4 --eval_acc 10 --output_len 512 \
6
+ --final_eval --prompt_format QCM-LE \
7
+ --evaluate_dir models/rationale
8
+
9
+ # answer inference
10
+ CUDA_VISIBLE_DEVICES=0,1 python main.py \
11
+ --model allenai/unifiedqa-t5-base \
12
+ --user_msg answer --img_type detr \
13
+ --bs 8 --eval_bs 4 --eval_acc 10 --output_len 64 \
14
+ --final_eval --prompt_format QCMG-A \
15
+ --eval_le models/rationale/predictions_ans_eval.json \
16
+ --test_le models/rationale/predictions_ans_test.json \
17
+ --evaluate_dir models/answer
run_training.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # rationale generation
2
+ CUDA_VISIBLE_DEVICES=0,1 python main.py \
3
+ --model allenai/unifiedqa-t5-base \
4
+ --user_msg rationale --img_type detr \
5
+ --bs 8 --eval_bs 4 --eval_acc 10 --output_len 512 \
6
+ --final_eval --prompt_format QCM-LE
7
+
8
+ # answer inference
9
+ CUDA_VISIBLE_DEVICES=0,1 python main.py \
10
+ --model allenai/unifiedqa-t5-base \
11
+ --user_msg answer --img_type detr \
12
+ --bs 8 --eval_bs 4 --eval_acc 10 --output_len 64 \
13
+ --final_eval --prompt_format QCMG-A \
14
+ --eval_le experiments/rationale_allenai-unifiedqa-t5-base_detr_QCM-LE_lr5e-05_bs16_op512_ep20/predictions_ans_eval.json \
15
+ --test_le experiments/rationale_allenai-unifiedqa-t5-base_detr_QCM-LE_lr5e-05_bs16_op512_ep20/predictions_ans_test.json
utils_data.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from torch.utils.data import Dataset
3
+ import os
4
+ import json
5
+ import numpy as np
6
+ import torch
7
+ from utils_prompt import *
8
+
9
+ img_shape = {
10
+ "resnet": (512, 2048),
11
+ "clip": (49, 2048),
12
+ "detr": (100, 256),
13
+ }
14
+
15
+ def load_data_std(args):
16
+ problems = json.load(open(os.path.join(args.data_root, 'scienceqa/problems.json')))
17
+ pid_splits = json.load(open(os.path.join(args.data_root, 'scienceqa/pid_splits.json')))
18
+ captions = json.load(open(args.caption_file))["captions"]
19
+
20
+ for qid in problems:
21
+ problems[qid]['caption'] = captions[qid] if qid in captions else ""
22
+
23
+ train_qids = pid_splits['%s' % (args.train_split)]
24
+ val_qids = pid_splits['%s' % (args.val_split)]
25
+ test_qids = pid_splits['%s' % (args.test_split)]
26
+ print(f"number of train problems: {len(train_qids)}\n")
27
+ print(f"number of val problems: {len(val_qids)}\n")
28
+ print(f"number of test problems: {len(test_qids)}\n")
29
+
30
+ qids = {'train': train_qids, 'val':val_qids,'test':test_qids}
31
+ return problems, qids,
32
+
33
+ def load_data_img(args):
34
+ problems = json.load(open(os.path.join(args.data_root, 'scienceqa/problems.json')))
35
+ pid_splits = json.load(open(os.path.join(args.data_root, 'scienceqa/pid_splits.json')))
36
+ captions = json.load(open(args.caption_file))["captions"]
37
+ name_maps = json.load(open('vision_features/name_map.json'))
38
+
39
+ # check
40
+ if args.img_type == "resnet":
41
+ image_features = np.load('vision_features/resnet.npy')
42
+ image_features = np.expand_dims(image_features, axis=1)
43
+ image_features = image_features.repeat(512, axis=1)
44
+ elif args.img_type == "clip":
45
+ image_features = np.load('vision_features/clip.npy')
46
+ elif args.img_type == "detr":
47
+ image_features = np.load('vision_features/detr.npy')
48
+ else:
49
+ image_features = np.load('vision_features/detr.npy')
50
+ print("img_features size: ", image_features.shape)
51
+
52
+ for qid in problems:
53
+ problems[qid]['caption'] = captions[qid] if qid in captions else ""
54
+
55
+ train_qids = pid_splits['%s' % (args.train_split)]
56
+ val_qids = pid_splits['%s' % (args.val_split)]
57
+ test_qids = pid_splits['%s' % (args.test_split)]
58
+ print(f"number of train problems: {len(train_qids)}\n")
59
+ print(f"number of val problems: {len(val_qids)}\n")
60
+ print(f"number of test problems: {len(test_qids)}\n")
61
+
62
+ qids = {'train': train_qids, 'val':val_qids,'test':test_qids}
63
+ return problems, qids, name_maps, image_features
64
+
65
+ class ScienceQADatasetStd(Dataset):
66
+ """
67
+ Creating a custom dataset for reading the dataset and
68
+ loading it into the dataloader to pass it to the
69
+ neural network for finetuning the model
70
+
71
+ """
72
+
73
+ def __init__(
74
+ self, problems, qids, tokenizer, source_len, target_len, args, test_le=None
75
+ ):
76
+ self.tokenizer = tokenizer
77
+ self.data = {qid : problems[qid] for qid in qids}
78
+ self.source_len = source_len
79
+ self.summ_len = target_len
80
+ self.target_text = []
81
+ self.source_text = []
82
+ if test_le is not None:
83
+ test_le_data =json.load(open(test_le))["preds"]
84
+ else:
85
+ test_le_data = None
86
+ idx = 0
87
+ for qid in self.data:
88
+ if test_le_data is not None:
89
+ curr_le_data = test_le_data[idx]
90
+ idx += 1
91
+ else:
92
+ curr_le_data = None
93
+ prompt, target = build_train_pair(problems, qid, args, curr_le_data)
94
+ self.target_text.append(target)
95
+ self.source_text.append(prompt)
96
+
97
+ def __len__(self):
98
+ return len(self.target_text)
99
+
100
+ def __getitem__(self, index):
101
+ source_text = str(self.source_text[index])
102
+ target_text = str(self.target_text[index])
103
+
104
+ # cleaning data so as to ensure data is in string type
105
+ source_text = " ".join(source_text.split())
106
+ target_text = " ".join(target_text.split())
107
+
108
+ source = self.tokenizer.batch_encode_plus(
109
+ [source_text],
110
+ max_length=self.source_len,
111
+ pad_to_max_length=True,
112
+ truncation=True,
113
+ padding="max_length",
114
+ return_tensors="pt",
115
+ )
116
+ target = self.tokenizer.batch_encode_plus(
117
+ [target_text],
118
+ max_length=self.summ_len,
119
+ pad_to_max_length=True,
120
+ truncation=True,
121
+ padding="max_length",
122
+ return_tensors="pt",
123
+ )
124
+ source_ids = source["input_ids"].squeeze()
125
+ source_mask = source["attention_mask"].squeeze()
126
+ target_ids = target["input_ids"].squeeze().tolist()
127
+
128
+ return {
129
+ "input_ids": source_ids,
130
+ "attention_mask": source_mask,
131
+ "labels": target_ids,
132
+ }
133
+
134
+
135
+ class ScienceQADatasetImg(Dataset):
136
+ """
137
+ Creating a custom dataset for reading the dataset and
138
+ loading it into the dataloader to pass it to the
139
+ neural network for finetuning the model
140
+
141
+ """
142
+
143
+ def __init__(
144
+ self, problems, qids, name_maps, tokenizer, source_len, target_len, args, image_features, test_le=None
145
+ ):
146
+ """
147
+ Initializes a Dataset class
148
+
149
+ Args:
150
+ dataframe (pandas.DataFrame): Input dataframe
151
+ tokenizer (transformers.tokenizer): Transformers tokenizer
152
+ source_len (int): Max length of source text
153
+ target_len (int): Max length of target text
154
+ source_text (str): column name of source text
155
+ target_text (str): column name of target text
156
+ """
157
+ self.tokenizer = tokenizer
158
+ self.data = {qid : problems[qid] for qid in qids}
159
+ self.source_len = source_len
160
+ self.summ_len = target_len
161
+ self.target_text = []
162
+ self.source_text = []
163
+ self.image_ids = []
164
+ if test_le is not None:
165
+ test_le_data =json.load(open(test_le))["preds"]
166
+ else:
167
+ test_le_data = None
168
+ idx = 0
169
+ for qid in self.data:
170
+ if test_le_data is not None:
171
+ curr_le_data = test_le_data[idx]
172
+ idx += 1
173
+ else:
174
+ curr_le_data = None
175
+ prompt, target = build_train_pair(problems, qid, args, curr_le_data)
176
+ self.target_text.append(target)
177
+ self.source_text.append(prompt)
178
+ if str(qid) in name_maps:
179
+ i_vectors = image_features[int(name_maps[str(qid)])]
180
+ self.image_ids.append(i_vectors)
181
+ else:
182
+ shape = img_shape[args.img_type]
183
+ self.image_ids.append(np.zeros(shape))
184
+
185
+ def __len__(self):
186
+ """returns the length of dataframe"""
187
+
188
+ return len(self.target_text)
189
+
190
+ def __getitem__(self, index):
191
+ """return the input ids, attention masks and target ids"""
192
+
193
+ source_text = str(self.source_text[index])
194
+ target_text = str(self.target_text[index])
195
+ image_ids = self.image_ids[index]
196
+
197
+ # cleaning data so as to ensure data is in string type
198
+ source_text = " ".join(source_text.split())
199
+ target_text = " ".join(target_text.split())
200
+
201
+ source = self.tokenizer.batch_encode_plus(
202
+ [source_text],
203
+ max_length=self.source_len,
204
+ pad_to_max_length=True,
205
+ truncation=True,
206
+ padding="max_length",
207
+ return_tensors="pt",
208
+ )
209
+ target = self.tokenizer.batch_encode_plus(
210
+ [target_text],
211
+ max_length=self.summ_len,
212
+ pad_to_max_length=True,
213
+ truncation=True,
214
+ padding="max_length",
215
+ return_tensors="pt",
216
+ )
217
+ source_ids = source["input_ids"].squeeze()
218
+ source_mask = source["attention_mask"].squeeze()
219
+ target_ids = target["input_ids"].squeeze().tolist()
220
+
221
+ image_ids = torch.tensor(image_ids).squeeze()
222
+
223
+ return {
224
+ "input_ids": source_ids,
225
+ "attention_mask": source_mask,
226
+ "image_ids": image_ids,
227
+ "labels": target_ids,
228
+ }
utils_evaluate.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Adapted from https://github.com/lupantech/ScienceQA
3
+ '''
4
+
5
+ import os
6
+ import json
7
+ import argparse
8
+ import warnings
9
+ import pandas as pd
10
+ from sentence_transformers import SentenceTransformer
11
+ from evaluations import caculate_bleu, caculate_rouge, caculate_similariry
12
+
13
+ warnings.filterwarnings('ignore')
14
+
15
+ def get_acc_with_contion(res_pd, key, values):
16
+ if isinstance(values, list):
17
+ total_pd = res_pd[res_pd[key].isin(values)]
18
+ else:
19
+ total_pd = res_pd[res_pd[key] == values]
20
+ correct_pd = total_pd[total_pd['true_false'] == True]
21
+ acc = "{:.2f}".format(len(correct_pd) / len(total_pd) * 100)
22
+ return acc
23
+
24
+
25
+ def get_scores(result_data, rationale_data, results_reference, data_file):
26
+ # read result file
27
+ results = result_data
28
+ num = len(results)
29
+ assert num == 4241
30
+ #print("number of questions:", num)
31
+
32
+ # read data file
33
+ sqa_data = json.load(open(data_file))
34
+
35
+ # construct pandas data
36
+ sqa_pd = pd.DataFrame(sqa_data).T
37
+ res_pd = sqa_pd[sqa_pd['split'] == 'test'] # test set
38
+
39
+ # update data
40
+ for index, row in res_pd.iterrows():
41
+
42
+ res_pd.loc[index, 'no_context'] = True if (not row['hint'] and not row['image']) else False
43
+ res_pd.loc[index, 'has_text'] = True if row['hint'] else False
44
+ res_pd.loc[index, 'has_image'] = True if row['image'] else False
45
+ res_pd.loc[index, 'has_text_image'] = True if (row['hint'] and row['image']) else False
46
+
47
+ label = row['answer']
48
+ pred = int(results[index])
49
+ res_pd.loc[index, 'pred'] = pred
50
+ res_pd.loc[index, 'true_false'] = (label == pred)
51
+
52
+ # accuracy scores
53
+ acc_average = len(res_pd[res_pd['true_false'] == True]) / num * 100
54
+ #assert result_file.split('_')[-1] == "{:.3f}.json".format(acc_average)
55
+
56
+
57
+ # rationale quality
58
+
59
+ ## BLEU
60
+ bleu1 = caculate_bleu(rationale_data, results_reference, gram=1)
61
+ bleu4 = caculate_bleu(rationale_data, results_reference, gram=4)
62
+
63
+ ## Rouge-L
64
+ rouge = caculate_rouge(rationale_data, results_reference)
65
+
66
+ ## Similarity
67
+ model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2').cuda()
68
+ similariry = caculate_similariry(rationale_data, results_reference, model)
69
+
70
+ scores = {
71
+ "answer":{
72
+ 'acc_natural':
73
+ get_acc_with_contion(res_pd, 'subject', 'natural science'),
74
+ 'acc_social':
75
+ get_acc_with_contion(res_pd, 'subject', 'social science'),
76
+ 'acc_language':
77
+ get_acc_with_contion(res_pd, 'subject', 'language science'),
78
+ 'acc_has_text':
79
+ get_acc_with_contion(res_pd, 'has_text', True),
80
+ 'acc_has_image':
81
+ get_acc_with_contion(res_pd, 'has_image', True),
82
+ 'acc_no_context':
83
+ get_acc_with_contion(res_pd, 'no_context', True),
84
+ 'acc_grade_1_6':
85
+ get_acc_with_contion(res_pd, 'grade', ['grade1', 'grade2', 'grade3', 'grade4', 'grade5', 'grade6']),
86
+ 'acc_grade_7_12':
87
+ get_acc_with_contion(res_pd, 'grade', ['grade7', 'grade8', 'grade9', 'grade10', 'grade11', 'grade12']),
88
+ 'acc_average':
89
+ "{:.2f}".format(acc_average),
90
+ },
91
+ "rationale":{
92
+ 'bleu1': bleu1 * 100,
93
+ 'bleu4': bleu4 * 100,
94
+ 'rouge': rouge * 100,
95
+ 'similariry': similariry * 100,
96
+ }
97
+ }
98
+
99
+ return scores
100
+
101
+
102
+ def print_scores(scores):
103
+ latex_output = ""
104
+ for key, score in scores.items():
105
+ print(f"{key[4:]}: \t{score}")
106
+ latex_output += f"& {score} "
107
+ latex_output += "\\\\"
108
+ print(latex_output)
utils_prompt.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Adapted from https://github.com/lupantech/ScienceQA
3
+ '''
4
+
5
+ from dataclasses import dataclass
6
+ from typing import List, Optional
7
+
8
+ def get_question_text(problem):
9
+ question = problem['question']
10
+ return question
11
+
12
+
13
+ def get_context_text(problem, use_caption):
14
+ txt_context = problem['hint']
15
+ img_context = problem['caption'] if use_caption else ""
16
+ context = " ".join([txt_context, img_context]).strip()
17
+ if context == "":
18
+ context = "N/A"
19
+ return context
20
+
21
+
22
+ def get_choice_text(probelm, options):
23
+ choices = probelm['choices']
24
+ choice_list = []
25
+ for i, c in enumerate(choices):
26
+ choice_list.append("({}) {}".format(options[i], c))
27
+ choice_txt = " ".join(choice_list)
28
+ #print(choice_txt)
29
+ return choice_txt
30
+
31
+ def get_origin_answer(problem, options):
32
+ return problem['choices'][problem['answer']]
33
+
34
+ def get_answer(problem, options):
35
+ return options[problem['answer']]
36
+
37
+
38
+ def get_lecture_text(problem):
39
+ # \\n: GPT-3 can generate the lecture with more tokens.
40
+ lecture = problem['lecture'].replace("\n", "\\n")
41
+ return lecture
42
+
43
+
44
+ def get_solution_text(problem):
45
+ # \\n: GPT-3 can generate the solution with more tokens
46
+ solution = problem['solution'].replace("\n", "\\n")
47
+ return solution
48
+
49
+
50
+ def create_one_example(format, question, context, choice, answer, lecture, solution, test_example=True, WithOutput = False, curr_le_data=None):
51
+
52
+ input_format, output_format = format.split("-")
53
+
54
+ ## Inputs
55
+ if input_format == "CQM":
56
+ input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n"
57
+ elif input_format == "QCM":
58
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n"
59
+ elif input_format == "QM":
60
+ input = f"Question: {question}\nOptions: {choice}\n"
61
+ elif input_format == "QC":
62
+ input = f"Question: {question}\nContext: {context}\n"
63
+ elif input_format == "QCMG":
64
+ if curr_le_data is not None:
65
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n{curr_le_data}\n"
66
+ else:
67
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nSolution: {lecture} {solution}\n"
68
+ elif input_format == "CQMG":
69
+ if curr_le_data is not None:
70
+ input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n{curr_le_data}\n"
71
+ else:
72
+ input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\nSolution: {lecture} {solution}\n"
73
+ # upper bound experiment
74
+ elif input_format == "QCML":
75
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n"
76
+ elif input_format == "QCME":
77
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n"
78
+ elif input_format == "QCMLE":
79
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n"
80
+
81
+ elif input_format == "QCLM":
82
+ input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n"
83
+ elif input_format == "QCEM":
84
+ input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n"
85
+ elif input_format == "QCLEM":
86
+ input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n"
87
+ elif input_format == "QCMA":
88
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nAnswer: The answer is {answer}.\n"
89
+ elif input_format == "QCA":
90
+ input = f"Question: {question}\nContext: {context}\nAnswer: The answer is {answer}. \nBECAUSE:"
91
+
92
+ # Outputs
93
+ if test_example:
94
+ if output_format == 'A':
95
+ output = "Answer:"
96
+ elif output_format == 'E':
97
+ output = "Solution:"
98
+ else:
99
+ output = "Solution:"
100
+ elif output_format == 'A':
101
+ output = f"Answer: The answer is {answer}."
102
+
103
+ elif output_format == 'AL':
104
+ output = f"Answer: The answer is {answer}. BECAUSE: {solution}"
105
+ elif output_format == 'AE':
106
+ output = f"Answer: The answer is {answer}. BECAUSE: {lecture}"
107
+ elif output_format == 'ALE':
108
+ output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}"
109
+ elif output_format == 'AEL':
110
+ output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}"
111
+
112
+ elif output_format == 'LA':
113
+ output = f"Answer: {lecture} The answer is {answer}."
114
+ elif output_format == 'EA':
115
+ output = f"Answer: {solution} The answer is {answer}."
116
+ elif output_format == 'LEA':
117
+ output = f"Answer: {lecture} {solution} The answer is {answer}."
118
+ elif output_format == 'ELA':
119
+ output = f"Answer: {solution} {lecture} The answer is {answer}."
120
+
121
+ elif output_format == 'LE':
122
+ output = f"Solution: {lecture} {solution}."
123
+
124
+ elif output_format == 'E':
125
+ output = f"Solution: {solution}"
126
+
127
+
128
+ if WithOutput:
129
+ if output.endswith("BECAUSE:"):
130
+ output = output.replace("BECAUSE:", "").strip()
131
+ if output_format == 'E':
132
+ text = input + f'Solution:'
133
+ elif output_format == 'A':
134
+ text = input + f'Answer:'
135
+ else:
136
+ text = input + f'Solution:'
137
+ text = text.replace(" ", " ").strip()
138
+ output = output.replace(" ", " ").strip()
139
+ return text, output
140
+
141
+
142
+ text = input + output
143
+ text = text.replace(" ", " ").strip()
144
+ if text.endswith("BECAUSE:"):
145
+ text = text.replace("BECAUSE:", "").strip()
146
+ return text
147
+
148
+
149
+ def build_prompt(problems, shot_qids, test_qid, args):
150
+
151
+ examples = []
152
+
153
+ # n-shot training examples
154
+ for qid in shot_qids:
155
+ question = get_question_text(problems[qid])
156
+ context = get_context_text(problems[qid], args.use_caption)
157
+ choice = get_choice_text(problems[qid], args.options)
158
+ answer = get_answer(problems[qid], args.options)
159
+ lecture = get_lecture_text(problems[qid])
160
+ solution = get_solution_text(problems[qid])
161
+
162
+ train_example = create_one_example(args.prompt_format,
163
+ question,
164
+ context,
165
+ choice,
166
+ answer,
167
+ lecture,
168
+ solution,
169
+ test_example=False)
170
+ examples.append(train_example)
171
+
172
+ # test example
173
+ question = get_question_text(problems[test_qid])
174
+ context = get_context_text(problems[test_qid], args.use_caption)
175
+ choice = get_choice_text(problems[test_qid], args.options)
176
+ answer = get_answer(problems[test_qid], args.options)
177
+ lecture = get_lecture_text(problems[test_qid])
178
+ solution = get_solution_text(problems[test_qid])
179
+
180
+ test_example = create_one_example(args.prompt_format,
181
+ question,
182
+ context,
183
+ choice,
184
+ answer,
185
+ lecture,
186
+ solution,
187
+ test_example=True)
188
+ examples.append(test_example)
189
+
190
+ # create the prompt input
191
+ prompt_input = '\n\n'.join(examples)
192
+
193
+ return prompt_input
194
+
195
+ def build_train_pair(problems, test_qid, args, curr_le_data=None):
196
+
197
+ examples = []
198
+
199
+ # test example
200
+ question = get_question_text(problems[test_qid])
201
+ context = get_context_text(problems[test_qid], args.use_caption)
202
+ choice = get_choice_text(problems[test_qid], args.options)
203
+
204
+ lecture = get_lecture_text(problems[test_qid])
205
+ solution = get_solution_text(problems[test_qid])
206
+
207
+ # answer_text = get_origin_answer(problems[test_qid], args.options)
208
+ answer_option = get_answer(problems[test_qid], args.options)
209
+ answer = "(" + answer_option + ")"
210
+
211
+ test_example, target = create_one_example(args.prompt_format,
212
+ question,
213
+ context,
214
+ choice,
215
+ answer,
216
+ lecture,
217
+ solution,
218
+ test_example=False,WithOutput = True, curr_le_data=curr_le_data)
219
+ examples.append(test_example)
220
+
221
+ target = target.replace("Answer:", "").strip()
222
+ # create the prompt input
223
+ prompt_input = '\n\n'.join(examples)
224
+
225
+ return prompt_input, target
226
+
227
+ @dataclass(frozen=True)
228
+ class InputFeatures:
229
+ """
230
+ A single set of features of data.
231
+ Property names are the same names as the corresponding inputs to a model.
232
+ """
233
+
234
+ input_ids: List[List[int]]
235
+ attention_mask: Optional[List[List[int]]]
236
+ token_type_ids: Optional[List[List[int]]]
237
+ le_input_ids: List[List[int]]
238
+ le_attention_mask: Optional[List[List[int]]]
239
+ le_token_type_ids: Optional[List[List[int]]]
240
+ label: Optional[int]
vision_features/mm-cot.png ADDED