|
|
|
import logging |
|
|
|
import argparse |
|
import pickle |
|
import csv |
|
import collections |
|
import itertools |
|
import copy |
|
from setup_database import get_document_store, add_data |
|
from setup_modules import create_retriever, create_readers_and_pipeline, text_reader_types, table_reader_types |
|
from eval_helper import create_labels |
|
|
|
def parse_args():
    """Parse the command-line options of the JointXplore evaluation script.

    Options are read from ``sys.argv``.

    Returns:
        argparse.Namespace: namespace with the attributes ``context`` (list
        of data set names), ``text_reader``, ``api_key``, ``table_reader``
        and ``seperate_evaluation``.
    """
    parser = argparse.ArgumentParser(description="JointXplore")

    parser.add_argument(
        "--context",
        help='which information should be added as context, subset of [processed_website_tables, processed_website_text, processed_schedule_tables], enter as multiple strings',
        nargs='+',
        default=["processed_website_tables", "processed_website_text", "processed_schedule_tables"],
    )
    parser.add_argument(
        "--text_reader",
        help="specify the model to use as text reader",
        choices=["minilm", "distilroberta", "electra-base", "bert-base", "deberta-large", "gpt3"],
        default="bert-base",
    )
    # The API key is a value, not a switch. ``nargs="?"`` lets the user pass
    # the key itself, while ``const=True`` / ``default=False`` keep the old
    # results for a bare ``--api-key`` flag (True) and an omitted one (False).
    parser.add_argument(
        "--api-key",
        help="if gpt3 choosen as reader, please provide api-key",
        nargs="?",
        const=True,
        default=False,
    )
    parser.add_argument(
        "--table_reader",
        help="choose tapas or convert table to text file and treat them as such",
        choices=["tapas", "text"],
        default="tapas",
    )
    parser.add_argument(
        "--seperate_evaluation",
        help="if specified, student generated questions and synthetically generated questions are evaluated seperately",
        action="store_true",
    )

    return parser.parse_args()
|
|
|
def _read_results_header(path):
    """Return the first row of the CSV file at ``path``, or None if the file is missing or empty."""
    try:
        with open(path, "r", newline='') as f:
            return next(csv.reader(f), None)
    except FileNotFoundError:
        return None


def _append_results_row(path, header, row):
    """Append ``row`` to the CSV at ``path`` and return the header in use.

    When ``header`` is None (no usable header could be read beforehand) the
    column names are taken from ``row`` and written as a header line first.
    """
    write_header = header is None
    if write_header:
        header = list(row)
    with open(path, "a", newline='') as f:
        writer = csv.DictWriter(f, fieldnames=header)
        if write_header:
            writer.writeheader()
        writer.writerow(row)
    return header


def _merge_weighted(first, second):
    """Combine the metric rows of two datasets into a single "all_eval" row.

    String entries are copied from ``first``, ``num_examples_for_eval`` is
    summed, and every other value is averaged, weighted by the number of
    evaluated examples of each dataset.
    """
    weight_old = first["num_examples_for_eval"]
    weight_new = second["num_examples_for_eval"]
    total_num_samples = weight_old + weight_new
    print("Weights for datasets:", weight_old, weight_new)
    print("new")

    merged = {}
    for key, val in first.items():
        if isinstance(val, str):
            merged[key] = val
        elif key == "num_examples_for_eval":
            merged[key] = val + second[key]
        elif total_num_samples:
            merged[key] = (val * weight_old + second[key] * weight_new) / total_num_samples
        else:
            # Neither dataset evaluated any example, so there is nothing to
            # weight; keep the first value instead of dividing by zero.
            merged[key] = val
    merged["Label type"] = "all_eval"
    return merged


def main(*args):
    """Build the QA pipeline, evaluate it and append the metrics to ./output/results.csv.

    Args:
        *args: either empty, in which case the options are taken from the
            command line via ``parse_args()``, or exactly four positional
            values ``(filenames, text_reader, table_reader,
            seperate_evaluation)``.

    Raises:
        ValueError: if a number of positional values other than zero or four
            is given, or if the evaluation metrics contain none of the
            expected reader nodes.
    """
    # ``args`` is a tuple, so it can never equal ``{}``; test for emptiness.
    if not args:
        cli_args = parse_args()
        filenames = cli_args.context
        text_reader = cli_args.text_reader
        table_reader = cli_args.table_reader
        seperate_evaluation = cli_args.seperate_evaluation
    else:
        filenames, text_reader, table_reader, seperate_evaluation = args
    print(f"Filenames: {filenames}")

    use_table = "processed_schedule_tables" in filenames
    # Each name needs its own membership test: a bare non-empty string on the
    # left of ``or`` is always truthy.
    use_text = (
        "processed_website_text" in filenames
        or "processed_website_tables" in filenames
    )

    logging.basicConfig(format="%(levelname)s - %(name)s - %(message)s", level=logging.WARNING)
    logging.getLogger("haystack").setLevel(logging.WARNING)
    logging.info("Starting..")

    document_index = "document"
    document_store = get_document_store(document_index)
    print(f"Number of docs previously: {len(document_store.get_all_documents())}")
    document_store, data = add_data(filenames, document_store, document_index)
    print(f"Number of docs after: {len(document_store.get_all_documents())}")

    document_store, retriever = create_retriever(document_store)
    text_reader_type = text_reader_types[text_reader]
    table_reader_type = table_reader_types[table_reader]
    pipeline = create_readers_and_pipeline(retriever, text_reader_type, table_reader_type, use_table, use_text)

    # Column order of the results file; None when there is no header yet.
    results_csv = "./output/results.csv"
    header = _read_results_header(results_csv)

    labels_file = "./data/validation_data/processed_qa.json"
    labels = create_labels(labels_file, data, seperate_evaluation)
    label_types = ["all_eval"]
    if seperate_evaluation:
        label_types = ["students", "synthetic"]

    context_name = ('_').join(filenames)
    csv_dict = None
    for idx, label in enumerate(labels):
        print(f"Label Dataset: {idx}")
        results = pipeline.eval(label, params={"top_k": 10}, sas_model_name_or_path="cross-encoder/stsb-roberta-large")
        res_dict = results.calculate_metrics()
        print(res_dict)

        with open(f"./output/{text_reader}_{table_reader}_{context_name}_{label_types[idx]}", "wb") as fp:
            pickle.dump(results, fp)

        exp_dict = {
            "Text Reader": text_reader,
            "Table Reader": table_reader,
            "Context": context_name,
            "Label type": label_types[idx],
        }

        # Use the metrics of the first reader node present, in this order.
        for reader_node in ("JoinAnswers", "TableReader", "TextReader"):
            if reader_node in res_dict:
                csv_dict_new = {**res_dict['EmbeddingRetriever'], **res_dict[reader_node], **exp_dict}
                break
        else:
            raise ValueError(
                "evaluation metrics contain none of the nodes "
                f"'JoinAnswers', 'TableReader', 'TextReader': {list(res_dict)}"
            )

        if idx == 1:
            # Second dataset: also store the example-weighted combination of
            # the first and second dataset.
            csv_dict_all = _merge_weighted(csv_dict, csv_dict_new)
            header = _append_results_row(results_csv, header, csv_dict_all)

        csv_dict = copy.deepcopy(csv_dict_new)
        print(csv_dict)
        header = _append_results_row(results_csv, header, csv_dict)

    document_store.delete_index(document_index)
|
|
|
|
|
# Script entry point: run the evaluation without positional arguments.
if __name__ == "__main__":
    main()