|
|
|
|
|
import sys |
|
import os |
|
sys.path.append("..") |
|
|
|
import re |
|
import json |
|
import fire |
|
import string |
|
|
|
from tqdm.autonotebook import tqdm |
|
from medalpaca.inferer import Inferer |
|
|
|
|
|
# Decoding presets passed as **kwargs to the model call (HuggingFace
# `generate()`-style parameters). Only `sampling` is used by `main` below;
# the others are kept as alternative configurations.

# Greedy decoding: deterministically pick the most likely token at each step.
greedy_search = {
    "num_beams" : 1,
    "do_sample" : False,
    "max_new_tokens" : 128,
    "early_stopping" : False
}

# Beam search with 4 beams.
# NOTE(review): the name is misspelled ("serach"); kept as-is because other
# modules may import it by this name — confirm before renaming.
beam_serach = {
    "num_beams" : 4,
    "do_sample" : False,
    "max_new_tokens" : 128,
    "early_stopping" : True,
}

# Top-k sampling: sample from the 50 most likely tokens, softened by temperature.
sampling_top_k = {
    "do_sample" : True,
    "num_beams": 1,
    "max_new_tokens": 128,
    "early_stopping": True,
    "temperature": 0.7,
    "top_k": 50
}

# Nucleus (top-p) sampling: top_k=0 disables the top-k filter so only the
# cumulative-probability cutoff (top_p) applies.
sampling_top_p = {
    "do_sample" : True,
    "top_k": 0,
    "num_beams": 1,
    "max_new_tokens": 128,
    "early_stopping": True,
    "temperature": 0.7,
    "top_p": 0.9
}

# Combined top-k + top-p sampling at a lower temperature; this is the preset
# `main` actually uses for answering exam questions.
sampling = {
    "do_sample" : True,
    "top_k": 50,
    "num_beams": 1,
    "max_new_tokens": 128,
    "early_stopping": True,
    "temperature": 0.4,
    "top_p": 0.9
}
|
|
|
|
|
def format_question(d):
    """Render a question dict as a prompt string.

    Takes a dict with a "question" string and an "options" mapping
    (letter -> answer text) and returns the question followed by one
    "letter: text" line per option.
    """
    option_lines = [f"\n{letter}: {text}" for letter, text in d["options"].items()]
    return d["question"] + "".join(option_lines)
|
|
|
|
|
def strip_special_chars(input_str):
    """Strip non-alphanumeric (ASCII) characters from both ends of a string.

    Falsy inputs (None, "") are returned unchanged. A string containing no
    ASCII letter or digit at all collapses to "".
    """
    if not input_str:
        return input_str

    # Hoisted out of the loops: the original rebuilt
    # `string.ascii_letters + string.digits` on every single iteration.
    alphanumeric = frozenset(string.ascii_letters + string.digits)

    start = 0
    end = len(input_str)
    while start < end and input_str[start] not in alphanumeric:
        start += 1
    while end > start and input_str[end - 1] not in alphanumeric:
        end -= 1
    # start == end when nothing alphanumeric was found, yielding "".
    return input_str[start:end]
|
|
|
def starts_with_capital_letter(input_str):
    """Return True when the answer begins like a labeled choice.

    Accepted shapes (followed by at least one character):
        'A: ...'
        'A. ...'
        'A ...'
    """
    answer_shape = re.compile(r'^[A-Z](:|\.|) .+')
    return answer_shape.match(input_str) is not None
|
|
|
|
|
def main(
    model_name: str,
    prompt_template: str,
    base_model: str,
    peft: bool,
    load_in_8bit: bool,
    path_to_exams: str,
    ntries: int = 5,
    skip_if_exists: bool = True,
):
    """Answer every multiple-choice question in `<path_to_exams>/test.jsonl`.

    Loads the model via `Inferer`, generates an answer for each question
    (retrying up to `ntries` times until the output looks like 'A: ...'),
    and persists all answers to `step_<model>.json` in the same directory
    after each question so progress survives interruptions.

    Args:
        model_name: model identifier; its last path segment names the output file.
        prompt_template: prompt template passed through to `Inferer`.
        base_model: base model passed through to `Inferer`.
        peft: whether to load PEFT weights (passed through to `Inferer`).
        load_in_8bit: whether to load the model in 8-bit (passed through).
        path_to_exams: directory containing `test.jsonl`, one JSON question per line.
        ntries: max generation attempts per question before accepting the
            last (possibly malformed) response.
        skip_if_exists: skip questions already present in the output file.
    """
    model = Inferer(
        model_name=model_name,
        prompt_template=prompt_template,
        base_model=base_model,
        peft=peft,
        load_in_8bit=load_in_8bit,
    )

    # One JSON object per line, each with "question" and "options" keys.
    with open(os.path.join(path_to_exams, "test.jsonl")) as f:
        questions = [json.loads(line) for line in f]

    # Resume from a previous run when the output file already exists.
    outname = os.path.join(path_to_exams, f"step_{model_name.split('/')[-1]}.json")
    if os.path.exists(outname):
        with open(outname, "r") as fp:
            answers = json.load(fp)
    else:
        answers = []

    # Bug fix: `skip_if_exists` was previously accepted but ignored. Track
    # already-answered questions by their question text so reruns skip them.
    answered = {a["question"] for a in answers} if skip_if_exists else set()

    print(f"Answering {len(questions)} questions")
    for question in tqdm(questions):
        if question["question"] in answered:
            continue

        response = None
        # Bug fix: the original had no inner retry loop, so `break` on the
        # first well-formed answer exited the *questions* loop entirely and
        # `ntries` was never honored. Retry generation up to `ntries` times.
        for n in range(ntries):
            response = model(
                instruction="Answer this multiple choice question.",
                input=format_question(question),
                output="The Answer to the question is:",
                **sampling
            )
            response = strip_special_chars(response)
            if starts_with_capital_letter(response):
                break
            print(f"Output not satisfactory, retrying {n+1}/{ntries}")

        # Keep the last response even if it never matched the expected shape.
        question["answer"] = response
        answers.append(question)

        # Persist after every question so an interrupted run can resume.
        with open(outname, "w") as fp:
            json.dump(answers, fp)
|
|
|
|
|
if __name__ == "__main__":
    # Expose `main` as a CLI: fire maps its parameters to command-line flags.
    fire.Fire(main)
|
|