File size: 5,627 Bytes
af461f3 b06298d af461f3 51a31d4 f2e3e47 b06298d 51dabd6 af461f3 ab5dfc2 b06298d af461f3 ab5dfc2 51a31d4 325e3c6 b06298d ab5dfc2 f2e3e47 ab5dfc2 51dabd6 51a31d4 e9df5ab 51a31d4 b06298d b7158e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import random
from typing import Dict, cast
import torch
import transformers
from datasets import DatasetDict, load_dataset
from dotenv import load_dotenv
from query import print_answers
from src.evaluation import evaluate
from src.readers.dpr_reader import DprReader
from src.retrievers.base_retriever import Retriever
from src.retrievers.es_retriever import ESRetriever
from src.retrievers.faiss_retriever import FaissRetriever
from src.utils.log import get_logger
from src.utils.preprocessing import context_to_reader_input
from src.utils.timing import get_times, timeit
logger = get_logger()
load_dotenv()
transformers.logging.set_verbosity_error()
if __name__ == '__main__':
    # Run a small retriever+reader QA experiment over a subset of the
    # GroNLP/ik-nlp-22_slp dataset and print the timing results.
    dataset_name = "GroNLP/ik-nlp-22_slp"
    # Fix: reuse dataset_name instead of re-hardcoding the literal string.
    paragraphs = cast(DatasetDict, load_dataset(dataset_name, "paragraphs"))
    questions = cast(DatasetDict, load_dataset(dataset_name, "questions"))

    # Only doing a few questions for speed
    subset_idx = 3
    questions_test = questions["test"][:subset_idx]

    # Experiment name -> retriever; the name labels the timing entries below.
    experiments: Dict[str, Retriever] = {
        "faiss": FaissRetriever(paragraphs),
        # "es": ESRetriever(paragraphs),
    }

    for experiment_name, retriever in experiments.items():
        reader = DprReader()

        for idx in range(subset_idx):
            question = questions_test["question"][idx]
            # Gold answer — currently unused; kept for the planned
            # evaluation step (see TODO below).
            answer = questions_test["answer"][idx]

            scores, context = retriever.retrieve(question, 5)
            reader_input = context_to_reader_input(context)

            # workaround so we can use the decorator with a dynamic name
            # for time recording
            time_wrapper = timeit(f"{experiment_name}.read")
            answers = time_wrapper(reader.read)(question, reader_input, 5)

            # Calculate softmaxed scores for readable output
            sm = torch.nn.Softmax(dim=0)
            document_scores = sm(torch.Tensor(
                [pred.relevance_score for pred in answers]))
            span_scores = sm(torch.Tensor(
                [pred.span_score for pred in answers]))

            print_answers(answers, scores, context)

            # TODO evaluation and storing of results

    times = get_times()
    print(times)
# TODO evaluation and storing of results
# # Initialize retriever
# retriever = FaissRetriever(paragraphs)
# # retriever = ESRetriever(paragraphs)
# # Retrieve example
# # random.seed(111)
# random_index = random.randint(0, len(questions_test["question"])-1)
# example_q = questions_test["question"][random_index]
# example_a = questions_test["answer"][random_index]
# scores, result = retriever.retrieve(example_q)
# reader_input = context_to_reader_input(result)
# # TODO: use new code from query.py to clean this up
# # Initialize reader
# answers = reader.read(example_q, reader_input)
# # Calculate softmaxed scores for readable output
# sm = torch.nn.Softmax(dim=0)
# document_scores = sm(torch.Tensor(
# [pred.relevance_score for pred in answers]))
# span_scores = sm(torch.Tensor(
# [pred.span_score for pred in answers]))
# print(example_q)
# for answer_i, answer in enumerate(answers):
# print(f"[{answer_i + 1}]: {answer.text}")
# print(f"\tDocument {answer.doc_id}", end='')
# print(f"\t(score {document_scores[answer_i] * 100:.02f})")
# print(f"\tSpan {answer.start_index}-{answer.end_index}", end='')
# print(f"\t(score {span_scores[answer_i] * 100:.02f})")
# print() # Newline
# # print(f"Example q: {example_q} answer: {result['text'][0]}")
# # for i, score in enumerate(scores):
# # print(f"Result {i+1} (score: {score:.02f}):")
# # print(result['text'][i])
# # Determine best answer we want to evaluate
# highest, highest_index = 0, 0
# for i, value in enumerate(span_scores):
# if value + document_scores[i] > highest:
# highest = value + document_scores[i]
# highest_index = i
# # Retrieve exact match and F1-score
# exact_match, f1_score = evaluate(
# example_a, answers[highest_index].text)
# print(f"Gold answer: {example_a}\n"
# f"Predicted answer: {answers[highest_index].text}\n"
# f"Exact match: {exact_match:.02f}\n"
# f"F1-score: {f1_score:.02f}")
# Calculate overall performance
# total_f1 = 0
# total_exact = 0
# total_len = len(questions_test["question"])
# start_time = time.time()
# for i, question in enumerate(questions_test["question"]):
# print(question)
# answer = questions_test["answer"][i]
# print(answer)
#
# scores, result = retriever.retrieve(question)
#    reader_input = context_to_reader_input(result)
# answers = reader.read(question, reader_input)
#
# document_scores = sm(torch.Tensor(
# [pred.relevance_score for pred in answers]))
# span_scores = sm(torch.Tensor(
# [pred.span_score for pred in answers]))
#
# highest, highest_index = 0, 0
# for j, value in enumerate(span_scores):
# if value + document_scores[j] > highest:
# highest = value + document_scores[j]
# highest_index = j
# print(answers[highest_index])
# exact_match, f1_score = evaluate(answer, answers[highest_index].text)
# total_f1 += f1_score
# total_exact += exact_match
# print(f"Total time:", round(time.time() - start_time, 2), "seconds.")
# print(total_f1)
# print(total_exact)
# print(total_f1/total_len)
|