# DAN_AI / prompt.py
# Last commit 772f8cb by oliverwang15: "updates on the submit button"
from template import TEMPLATE_v1, TEMPLATE_v2, TEMPLATE_v3, QUESTIONS
import json
class Prompt:
    """Build LLM prompts from templates and parse the model's JSON replies.

    Three prompt/response formats are supported, selected by a version tag
    (``'v1'``, ``'v2'`` or ``'v3'``).  The most recently requested version is
    stored in ``self.version`` and is used whenever a call omits ``version``.
    """

    # Version tags accepted by get() and process_result().
    _VERSIONS = ('v1', 'v2', 'v3')

    def __init__(self) -> None:
        self.template_v1 = TEMPLATE_v1
        self.template_v2 = TEMPLATE_v2
        self.template_v3 = TEMPLATE_v3
        self.version = "v3"

    def _check_version(self, version):
        """Raise ValueError unless *version* is one of the supported tags."""
        if version not in self._VERSIONS:
            raise ValueError('Version should be one of {v1, v2, v3}')

    def combine_questions(self, questions):
        """Join *questions* into one newline-separated, numbered string.

        Each entry is rendered as ``'Question N: <text>'``, where N is the
        1-based position of the entry in *questions*.  Entries containing the
        substring ``'Input question'`` are dropped; the remaining entries keep
        their original position numbers, so gaps may appear.
        """
        numbered = [
            f'Question {idx + 1}: {q}'
            for idx, q in enumerate(questions)
            if 'Input question' not in q
        ]
        return '\n'.join(numbered)

    def _get_v1(self, input, questions):
        """Fill the v1 template with the input text and the numbered questions."""
        combined = self.combine_questions(questions)
        return self.template_v1.format(input, combined)

    def _get_v2(self, input, questions):
        """Fill the v2 template with the input text and the numbered questions."""
        combined = self.combine_questions(questions)
        return self.template_v2.format(input, combined)

    def _get_v3(self, input, questions):
        """Fill the v3 template with the input text only.

        *questions* is accepted for signature parity but not used.
        NOTE(review): presumably TEMPLATE_v3 embeds its own questions - confirm.
        """
        return self.template_v3.format(input)

    def get(self, input, questions, version=None):
        """Return the prompt for *input* and *questions*.

        Args:
            input: text substituted into the template's first placeholder.
            questions: iterable of question strings (used by v1/v2 only).
            version: optional version tag; when truthy it becomes the new
                default stored in ``self.version``.

        Raises:
            ValueError: if the requested (or stored) version is unsupported.
                The stored version is left unchanged in that case.
        """
        target = version if version else self.version
        # Validate before storing so a bad tag cannot poison later calls.
        self._check_version(target)
        self.version = target
        builders = {
            'v1': self._get_v1,
            'v2': self._get_v2,
            'v3': self._get_v3,
        }
        return builders[target](input, questions)

    def _process_v1(self, res):
        """Parse a v1 reply: plain JSON decoding."""
        return json.loads(res)

    def _process_v2(self, res):
        """Parse a v2 reply: plain JSON decoding."""
        return json.loads(res)

    def _process_v3(self, x):
        """Parse a v3 reply into ``{'Question N': {...}}`` entries.

        *x* is a JSON object whose values are either a leaf (a dict holding
        ``'answer'`` and ``'original sentences'``) or a group (a dict whose
        values are such leaves).  Each leaf, in document order, becomes
        ``'Question N'`` with N counting up from 1.
        """
        parsed = json.loads(x)
        res = {}
        for value in parsed.values():
            # A leaf carries its own answer; otherwise it groups sub-questions.
            entries = [value] if 'answer' in value else list(value.values())
            for entry in entries:
                res[f'Question {len(res) + 1}'] = {
                    "answer": entry['answer'],
                    "original sentences": entry['original sentences'],
                }
        return res

    def process_result(self, result, version=None):
        """Parse the model's raw JSON *result* according to a version's format.

        Args:
            result: raw JSON text returned by the model.
            version: optional version tag; when given and different from the
                stored one, it replaces ``self.version`` (and a notice is
                printed).  When omitted, the stored version is used.

        Raises:
            ValueError: if the version is unsupported (the stored version is
                left unchanged).
        """
        if version is not None and version != self.version:
            self._check_version(version)
            self.version = version
            print(f'Version changed to {version}')
        self._check_version(self.version)
        parsers = {
            'v1': self._process_v1,
            'v2': self._process_v2,
            'v3': self._process_v3,
        }
        return parsers[self.version](result)