dmahata committed
Commit fd40c9d
1 Parent(s): 42f951a

Upload run_eval.py

Files changed (1):
run_eval.py +282 -0
run_eval.py ADDED
@@ -0,0 +1,282 @@
#!/usr/bin/env python
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import datetime
import json
import time
import warnings
from logging import getLogger
from pathlib import Path
from typing import Dict, List

import torch
from tqdm import tqdm

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from utils import (
    calculate_bleu,
    calculate_rouge,
    chunks,
    parse_numeric_n_bool_cl_kwargs,
    use_task_specific_params,
)

from evaluate_gpt import gpt_eval


logger = getLogger(__name__)


DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def generate_summaries_or_translations(
    examples: List[str],
    out_file: str,
    model_name: str,
    batch_size: int = 8,
    device: str = DEFAULT_DEVICE,
    fp16=False,
    task="summarization",
    prefix=None,
    **generate_kwargs,
) -> Dict:
    """Save model.generate results to <out_file>, and return how long it took."""
    fout = Path(out_file).open("w", encoding="utf-8")
    model_name = str(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    if fp16:
        model = model.half()

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    logger.info(
        f"Inferred tokenizer type: {tokenizer.__class__}"
    )  # if this is wrong, check config.model_type.

    start_time = time.time()
    # update config with task specific params
    use_task_specific_params(model, task)
    if prefix is None:
        prefix = getattr(model.config, "prefix", "") or ""
    for examples_chunk in tqdm(list(chunks(examples, batch_size))):
        examples_chunk = [prefix + text for text in examples_chunk]
        batch = tokenizer(
            examples_chunk, return_tensors="pt", truncation=True, padding="longest"
        ).to(device)
        summaries = model.generate(
            input_ids=batch.input_ids,
            attention_mask=batch.attention_mask,
            **generate_kwargs,
        )
        dec = tokenizer.batch_decode(
            summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        for hypothesis in dec:
            fout.write(hypothesis + "\n")
            fout.flush()
    fout.close()
    runtime = int(time.time() - start_time)  # seconds
    n_obs = len(examples)
    return dict(
        n_obs=n_obs, runtime=runtime, seconds_per_sample=round(runtime / n_obs, 4)
    )


def datetime_now():
    return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")


def run_generate(
    verbose=True,
    model_name_path=None,
    src_txt=None,
    tar_txt=None,
    gen_path=None,
    scor_path=None,
    batch_size=None,
):
    """
    Takes input text, generates output, and then scores the generations against the reference file
    (BLEU for translation tasks, ROUGE otherwise).

    The results are saved to a file and returned to the caller, and printed out unless ``verbose=False`` is passed.

    Args:
        verbose (:obj:`bool`, `optional`, defaults to :obj:`True`): print results to stdout

    Returns:
        ``scores``: a dict of metrics, e.g. ``{'bleu': 39.6501, 'n_obs': 2000, 'runtime': 186, 'seconds_per_sample': 0.093}``,
        extended with the custom generate kwargs (e.g. ``{'num_beams': 5, 'length_penalty': 0.8}``) when ``--dump-args`` is passed.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        type=str,
        required=False,
        help="like facebook/bart-large-cnn, t5-base, etc.",
    )
    parser.add_argument(
        "--input_path", type=str, required=False, help="like cnn_dm/test.source"
    )
    parser.add_argument(
        "--save_path", type=str, required=False, help="where to save summaries"
    )
    parser.add_argument(
        "--reference_path", type=str, required=False, help="like cnn_dm/test.target"
    )
    parser.add_argument(
        "--score_path",
        type=str,
        required=False,
        default="metrics.json",
        help="where to save metrics",
    )
    parser.add_argument(
        "--device",
        type=str,
        required=False,
        default=DEFAULT_DEVICE,
        help="cuda, cuda:1, cpu etc.",
    )
    parser.add_argument(
        "--prefix",
        type=str,
        required=False,
        default=None,
        help="will be added to the beginning of src examples",
    )
    parser.add_argument(
        "--task",
        type=str,
        default="summarization",
        help="used for task_specific_params + metrics",
    )
    parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
    parser.add_argument(
        "--n_obs",
        type=int,
        default=-1,
        required=False,
        help="How many observations. Defaults to all.",
    )
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument(
        "--dump-args",
        action="store_true",
        help="print the custom hparams with the results",
    )
    parser.add_argument(
        "--info",
        nargs="?",
        type=str,
        const=datetime_now(),
        help="use in conjunction w/ --dump-args to print with the results whatever other info you'd like, e.g. lang=en-ru. If no value is passed, the current datetime string will be used.",
    )
    # Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate
    args, rest = parser.parse_known_args()
    parsed_args = parse_numeric_n_bool_cl_kwargs(rest)

    # Explicit function arguments override the corresponding command-line flags.
    if model_name_path:
        args.model_name = model_name_path

    if src_txt:
        args.input_path = src_txt

    if tar_txt:
        args.reference_path = tar_txt

    if batch_size:
        args.bs = batch_size

    if gen_path:
        args.save_path = gen_path

    if scor_path:
        args.score_path = scor_path

    # GPT-style models are evaluated by a separate code path.
    if args.model_name.endswith("gpt"):
        gpt_eval(
            model_name_path=args.model_name,
            src_txt=args.input_path,
            tar_txt=args.reference_path,
            gen_path=args.save_path,
            scor_path=args.score_path,
            batch_size=args.bs,
        )
        return None

    if parsed_args and verbose:
        print(f"parsed the following generate kwargs: {parsed_args}")
    # For T5-style models, prepend a space to each source example.
    examples = [
        " " + x.rstrip() if "t5" in args.model_name else x.rstrip()
        for x in open(args.input_path).readlines()
    ]
    if args.n_obs > 0:
        examples = examples[: args.n_obs]
    Path(args.save_path).parent.mkdir(exist_ok=True)

    if args.reference_path is None and Path(args.score_path).exists():
        warnings.warn(
            f"score_path {args.score_path} will be overwritten unless you type ctrl-c."
        )

    if args.device == "cpu" and args.fp16:
        # this mix leads to RuntimeError: "threshold_cpu" not implemented for 'Half'
        raise ValueError("Can't mix --fp16 and --device cpu")

    runtime_metrics = generate_summaries_or_translations(
        examples,
        args.save_path,
        args.model_name,
        batch_size=args.bs,
        device=args.device,
        fp16=args.fp16,
        task=args.task,
        prefix=args.prefix,
        **parsed_args,
    )

    if args.reference_path is None:
        return {}

    # Compute scores
    score_fn = calculate_bleu if "translation" in args.task else calculate_rouge
    output_lns = [x.rstrip() for x in open(args.save_path).readlines()]
    reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()][
        : len(output_lns)
    ]
    scores: dict = score_fn(output_lns, reference_lns)
    scores.update(runtime_metrics)

    if args.dump_args:
        scores.update(parsed_args)
    if args.info:
        scores["info"] = args.info

    if verbose:
        print(scores)

    if args.score_path is not None:
        json.dump(scores, open(args.score_path, "w"))

    return scores


if __name__ == "__main__":
    # Example usage for MT (note this script takes named flags, not positional arguments):
    # python run_eval.py --model_name MODEL_NAME --input_path $DATA_DIR/test.source \
    #     --save_path $save_dir/test_translations.txt --reference_path $DATA_DIR/test.target \
    #     --score_path $save_dir/test_bleu.json --task translation $@
    run_generate(verbose=True)
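
For reference, a minimal sketch of driving run_generate() programmatically instead of through the CLI. The checkpoint name and file paths below are placeholders, not files shipped with this repo, and run_generate() still parses sys.argv internally, so extra command-line flags such as --num_beams=4 are forwarded to model.generate().

# example_run_eval.py -- hypothetical caller, not part of the uploaded file above
from run_eval import run_generate

# Keyword arguments override the corresponding CLI flags inside run_generate().
scores = run_generate(
    verbose=True,
    model_name_path="facebook/bart-large-cnn",  # any seq2seq checkpoint name or path
    src_txt="cnn_dm/test.source",               # one source document per line
    tar_txt="cnn_dm/test.target",               # references, same line order
    gen_path="output/test_generations.txt",     # where hypotheses are written
    scor_path="output/metrics.json",            # where metric + runtime results are saved
    batch_size=16,
)
print(scores)  # ROUGE (or BLEU for translation tasks) plus n_obs, runtime, seconds_per_sample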