Update run_eval.py
Browse files- run_eval.py +14 -1
run_eval.py
CHANGED
@@ -22,6 +22,7 @@ from logging import getLogger
|
|
22 |
from pathlib import Path
|
23 |
from typing import Dict, List
|
24 |
|
|
|
25 |
from tqdm import tqdm
|
26 |
|
27 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
@@ -33,12 +34,13 @@ from utils import (
|
|
33 |
use_task_specific_params,
|
34 |
)
|
35 |
|
|
|
36 |
|
37 |
|
38 |
logger = getLogger(__name__)
|
39 |
|
40 |
|
41 |
-
DEFAULT_DEVICE = "cpu"
|
42 |
|
43 |
|
44 |
def generate_summaries_or_translations(
|
@@ -206,6 +208,17 @@ def run_generate(
|
|
206 |
if scor_path:
|
207 |
args.score_path = scor_path
|
208 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
209 |
if parsed_args and verbose:
|
210 |
print(f"parsed the following generate kwargs: {parsed_args}")
|
211 |
examples = [
|
22 |
from pathlib import Path
|
23 |
from typing import Dict, List
|
24 |
|
25 |
+
import torch
|
26 |
from tqdm import tqdm
|
27 |
|
28 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
34 |
use_task_specific_params,
|
35 |
)
|
36 |
|
37 |
+
from evaluate_gpt import gpt_eval
|
38 |
|
39 |
|
40 |
logger = getLogger(__name__)
|
41 |
|
42 |
|
43 |
+
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
44 |
|
45 |
|
46 |
def generate_summaries_or_translations(
|
208 |
if scor_path:
|
209 |
args.score_path = scor_path
|
210 |
|
211 |
+
if args.model_name[-3:] == 'gpt':
|
212 |
+
gpt_eval(
|
213 |
+
model_name_path=args.model_name,
|
214 |
+
src_txt=args.input_path,
|
215 |
+
tar_txt=args.reference_path,
|
216 |
+
gen_path=args.save_path,
|
217 |
+
scor_path=args.score_path,
|
218 |
+
batch_size=args.bs
|
219 |
+
)
|
220 |
+
return None
|
221 |
+
|
222 |
if parsed_args and verbose:
|
223 |
print(f"parsed the following generate kwargs: {parsed_args}")
|
224 |
examples = [
|