auto-info / PSV /utils /game24 /decoding.py
rookiemango's picture
Upload folder using huggingface_hub
da66274 verified
raw
history blame contribute delete
802 Bytes
import re
import sympy
from typing import List
def extract_expression(response: str):
    """Pull the final arithmetic expression out of a model response.

    Only the last line of the stripped response is considered. That line is
    lower-cased, and everything up to and including the last occurrence of
    "the answer is " is discarded (the whole line is kept if the phrase is
    absent). Anything from the first "=" onward is also dropped.

    Args:
        response: Raw text generated for one Game-of-24 problem.

    Returns:
        The candidate expression with surrounding whitespace removed; may be
        an empty string.
    """
    final_line = response.strip().rpartition('\n')[2].lower()
    # rpartition yields the whole string in slot 2 when the separator is missing.
    after_marker = final_line.rpartition('the answer is ')[2]
    # Drop the right-hand side of e.g. "(4 + 4) * 3 = 24".
    before_equals = after_marker.partition('=')[0]
    return before_equals.strip()
def extract_expressions(responses: List[str]):
    """Run ``extract_expression`` over each response, preserving order.

    Args:
        responses: Raw model outputs, one per sample.

    Returns:
        A list with one extracted expression per input response.
    """
    return list(map(extract_expression, responses))
# refer to https://github.com/princeton-nlp/tree-of-thought-llm/blob/master/src/tot/tasks/game24.py
def get_answer_label(expression: str, question: str) -> bool:
    """Return True if *expression* is a valid Game-of-24 solution to *question*.

    The expression must use exactly the same multiset of integers as the
    question (both are scanned for runs of digits), contain only arithmetic
    characters, and evaluate with SymPy to 24.

    Args:
        expression: Candidate arithmetic expression, e.g. as returned by
            ``extract_expression``.
        question: Problem text containing the numbers that must be used.

    Returns:
        True when the numbers match and the expression equals 24; False
        otherwise, including when the expression cannot be parsed/evaluated.
    """
    numbers = re.findall(r'\d+', expression)
    problem_numbers = re.findall(r'\d+', question)
    # Sorted-list equality is a multiset check: each number used exactly once.
    if sorted(numbers) != sorted(problem_numbers):
        return False
    # sympy.simplify() on a str goes through sympify(), which calls eval().
    # The expression comes from model output, so only plain arithmetic
    # (digits, whitespace, + - * / ^ . and parentheses) is allowed through.
    # NOTE(review): '**' / '^' still permit huge powers (e.g. 9**9**9); the
    # game only needs + - * /, consider rejecting exponentiation too.
    if not re.fullmatch(r'[\d\s+\-*/^().]*', expression):
        return False
    try:
        value = sympy.simplify(expression)
        # Test the difference instead of ``value == 24``: newer SymPy
        # (1.13+) no longer treats Float(24.0) as equal to Integer(24), so an
        # answer like "(4+4)*3." would otherwise be scored wrong.
        return (value - 24).is_zero is True
    except Exception:
        # Best effort: any malformed expression (syntax error, unbalanced
        # parentheses, non-numeric result, ...) is scored as incorrect.
        return False