BounWiki / eval_helper.py
LeoGitGuy
added files
9bf0a0f
raw
history blame
2.87 kB
from haystack import Label, MultiLabel, Answer
import json
import re
def read_labels(labels, tables):
    """Build one gold ``MultiLabel`` per annotated question in ``tables``.

    Args:
        labels: Mapping of table ID to a list of SQuAD-style QA dicts, each
            with a ``"question"`` string and an ``"answers"`` list whose first
            element has a ``"text"`` key.
        tables: Iterable of Haystack documents exposing an ``id`` attribute.

    Returns:
        A list of ``MultiLabel`` objects, each wrapping a single ``Label`` that
        uses only the first gold answer of its question. Tables whose ID is not
        in ``labels`` contribute nothing.
    """
    result = []
    for table in tables:
        # Tables without annotated questions are skipped.
        if table.id not in labels:
            continue
        for qa in labels[table.id]:
            query = qa["question"]
            gold_answer = Answer(answer=qa["answers"][0]["text"])
            gold = Label(
                query=query,
                document=table,
                is_correct_answer=True,
                is_correct_document=True,
                answer=gold_answer,
                origin="gold-label",
            )
            result.append(MultiLabel(labels=[gold]))
    return result
def create_labels(labels_file, data, seperate_eval):
    """Load gold QA labels from JSON and convert them into Haystack MultiLabels.

    Args:
        labels_file: Path to a JSON file mapping document/table IDs to lists of
            SQuAD-style QA dicts (the shape consumed by ``read_labels``).
        data: Iterable of table collections; each collection is passed to
            ``read_labels`` together with the loaded labels.
        seperate_eval: If true, split the labels with ``filter_labels`` into a
            user-question set and a synthetic-question set and convert each
            separately; otherwise convert all labels as a single set.

    Returns:
        A list holding one list of MultiLabels per label set:
        ``[user, synthetic]`` when ``seperate_eval`` is true, otherwise a
        single-element list.
    """
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(labels_file, encoding="utf-8") as fp:
        all_labels = json.load(fp)

    label_sets = filter_labels(all_labels) if seperate_eval else [all_labels]

    # ``data`` is walked once per label set, so materialize it; otherwise a
    # one-shot iterator would yield nothing for the second (synthetic) set.
    data = list(data)

    eval_labels = []
    for label_set in label_sets:
        set_labels = []
        for tables in data:
            set_labels += read_labels(label_set, tables)
        print(f"Number of Labels: {len(set_labels)}")
        eval_labels.append(set_labels)
    return eval_labels
def get_processed_squad_labels(squad_labels, data_dir="./data/validation_data"):
    """Regroup a SQuAD-format export by the document IDs used for evaluation.

    Course-schedule paragraphs (contexts starting with the schedule's
    tab-separated header row) are grouped under the run of letters that begins
    the first data row, e.g. ``CMPE``. All other paragraphs are keyed by their
    ``document_id`` converted to ``str``. Only the first paragraph of each
    ``data`` entry is read. The result is written as JSON to
    ``<data_dir>/processed_qa.json``.

    Args:
        squad_labels: File name of the SQuAD JSON file inside ``data_dir``.
        data_dir: Directory holding both the input and the output file.
            Defaults to the previously hard-coded location.

    Returns:
        The ``{doc_id: [qa, ...]}`` mapping that was written to disk.
    """
    header = "Code\tName\tEcts\tInstructor\tDays\tHours\tRooms\n"

    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(f"{data_dir}/{squad_labels}", encoding="utf-8") as fp:
        squad = json.load(fp)

    processed = {}
    for entry in squad["data"]:
        paragraph = entry["paragraphs"][0]
        context = paragraph["context"]
        if context.startswith(header):
            # Leading letters after the header, e.g. "CMPE" from "CMPE 150 ...".
            key = re.match(r"[a-z]*", context[len(header):], re.IGNORECASE).group()
        else:
            key = str(paragraph["document_id"])
        # Merge instead of overwrite, so a key that appears in several entries
        # keeps all of its questions (for document IDs as well as faculties).
        processed.setdefault(key, []).extend(paragraph["qas"])

    with open(f"{data_dir}/processed_qa.json", "w", encoding="utf-8") as outfile:
        json.dump(processed, outfile)
    return processed
def filter_labels(labels, questions_path="./data/validation_data/questions_new.txt"):
    """Split gold labels into user-written and synthetic questions.

    A question counts as user-written when its stripped text equals a stripped
    line of ``questions_path``; every other question is synthetic.

    Args:
        labels: Mapping of document ID to a list of QA dicts, each with a
            ``"question"`` key.
        questions_path: Text file listing one user-written question per line.
            Defaults to the previously hard-coded location.

    Returns:
        ``[user_labels, synthetic_labels]``: two dicts shaped like ``labels``.
        A document appears in a dict only if at least one of its questions
        falls into that group.
    """
    # Read as UTF-8 explicitly so non-ASCII questions match on every platform.
    with open(questions_path, "r", encoding="utf-8") as fp:
        # A set gives O(1) membership tests instead of scanning every line.
        user_questions = {line.strip() for line in fp.read().split("\n")}

    user_squad_labels = {}
    synthetic_squad_labels = {}
    for doc, questions in labels.items():
        for q in questions:
            if q["question"].strip() in user_questions:
                target = user_squad_labels
            else:
                target = synthetic_squad_labels
            target.setdefault(doc, []).append(q)
    return [user_squad_labels, synthetic_squad_labels]