dmahata commited on
Commit
70b2a4a
1 Parent(s): b360761

Update run_eval.py

Browse files
Files changed (1) hide show
  1. run_eval.py +14 -1
run_eval.py CHANGED
@@ -22,6 +22,7 @@ from logging import getLogger
22
  from pathlib import Path
23
  from typing import Dict, List
24
 
 
25
  from tqdm import tqdm
26
 
27
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
@@ -33,12 +34,13 @@ from utils import (
33
  use_task_specific_params,
34
  )
35
 
 
36
 
37
 
38
  logger = getLogger(__name__)
39
 
40
 
41
- DEFAULT_DEVICE = "cpu"
42
 
43
 
44
  def generate_summaries_or_translations(
@@ -206,6 +208,17 @@ def run_generate(
206
  if scor_path:
207
  args.score_path = scor_path
208
 
 
 
 
 
 
 
 
 
 
 
 
209
  if parsed_args and verbose:
210
  print(f"parsed the following generate kwargs: {parsed_args}")
211
  examples = [
22
  from pathlib import Path
23
  from typing import Dict, List
24
 
25
+ import torch
26
  from tqdm import tqdm
27
 
28
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
34
  use_task_specific_params,
35
  )
36
 
37
+ from evaluate_gpt import gpt_eval
38
 
39
 
40
  logger = getLogger(__name__)
41
 
42
 
43
+ DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
44
 
45
 
46
  def generate_summaries_or_translations(
208
  if scor_path:
209
  args.score_path = scor_path
210
 
211
+ if args.model_name[-3:] == 'gpt':
212
+ gpt_eval(
213
+ model_name_path=args.model_name,
214
+ src_txt=args.input_path,
215
+ tar_txt=args.reference_path,
216
+ gen_path=args.save_path,
217
+ scor_path=args.score_path,
218
+ batch_size=args.bs
219
+ )
220
+ return None
221
+
222
  if parsed_args and verbose:
223
  print(f"parsed the following generate kwargs: {parsed_args}")
224
  examples = [