# imports
import logging
#import torch_scatter
import argparse
import pickle
import csv
import collections
import itertools
import copy

from setup_database import get_document_store, add_data
from setup_modules import create_retriever, create_readers_and_pipeline, text_reader_types, table_reader_types
from eval_helper import create_labels


def parse_args():
    parser = argparse.ArgumentParser(description="JointXplore")
    parser.add_argument("--context",
                        help="which information should be added as context, subset of "
                             "[processed_website_tables, processed_website_text, processed_schedule_tables], "
                             "enter as multiple strings",
                        nargs='+',
                        default=["processed_website_tables", "processed_website_text", "processed_schedule_tables"])
    parser.add_argument("--text_reader",
                        help="specify the model to use as text reader",
                        choices=["minilm", "distilroberta", "electra-base", "bert-base", "deberta-large", "gpt3"],
                        default="bert-base")
    parser.add_argument("--api-key",
                        help="if gpt3 is chosen as reader, provide the API key",
                        default=None)
    parser.add_argument("--table_reader",
                        help="use tapas, or convert tables to text files and treat them as such",
                        choices=["tapas", "text"],
                        default="tapas")
    parser.add_argument("--seperate_evaluation",
                        help="if specified, student-generated and synthetically generated questions are evaluated separately",
                        action="store_true")
    args = parser.parse_args()
    # if 'LOCAL_RANK' not in os.environ:
    #     os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args


def main(*args):
    # Called without arguments: read the configuration from the command line.
    # Otherwise the caller passes (filenames, text_reader, table_reader, seperate_evaluation) directly.
    if not args:
        args = parse_args()
        filenames = args.context
        text_reader = args.text_reader
        table_reader = args.table_reader
        seperate_evaluation = args.seperate_evaluation
    else:
        filenames, text_reader, table_reader, seperate_evaluation = args
    print(f"Filenames: {filenames}")

    use_table = False
    use_text = False
    if "processed_schedule_tables" in filenames:
        use_table = True
    if "processed_website_text" in filenames or "processed_website_tables" in filenames:
        use_text = True

    # configure logger
    logging.basicConfig(format="%(levelname)s - %(name)s - %(message)s", level=logging.WARNING)
    logging.getLogger("haystack").setLevel(logging.WARNING)
    logging.info("Starting..")

    # set up the document store and index the selected context files
    document_index = "document"
    document_store = get_document_store(document_index)
    print(f"Number of docs previously: {len(document_store.get_all_documents())}")
    document_store, data = add_data(filenames, document_store, document_index)
    print(f"Number of docs after: {len(document_store.get_all_documents())}")

    # build the retriever and the reader pipeline
    document_store, retriever = create_retriever(document_store)
    text_reader_type = text_reader_types[text_reader]
    table_reader_type = table_reader_types[table_reader]
    pipeline = create_readers_and_pipeline(retriever, text_reader_type, table_reader_type, use_table, use_text)

    # read the header row of the results file so appended rows match its columns
    with open("./output/results.csv", "r") as f:
        reader = csv.reader(f)
        header = next(reader)

    # create evaluation labels: one combined set, or student/synthetic sets when evaluated separately
    labels_file = "./data/validation_data/processed_qa.json"
    labels = create_labels(labels_file, data, seperate_evaluation)
    label_types = ["all_eval"]
    if seperate_evaluation:
        label_types = ["students", "synthetic"]

    for idx, label in enumerate(labels):
        # for la in label:
        #     for l in la.labels:
        #         print("CHECK WRONG DOC")
        #         print(l.document.content == "")
        #         print(l.document.id)
        print(f"Label Dataset: {idx}")

        # run the evaluation, scoring answers with a semantic answer similarity model
        results = pipeline.eval(label, params={"top_k": 10},
                                sas_model_name_or_path="cross-encoder/stsb-roberta-large")
        res_dict = results.calculate_metrics()
        print(res_dict)

        # persist the raw evaluation result object
        with open(f"./output/{text_reader}_{table_reader}_{('_').join(filenames)}_{label_types[idx]}", "wb") as fp:
            pickle.dump(results, fp)

        exp_dict = {
            "Text Reader": text_reader,
            "Table Reader": table_reader,
            "Context": ('_').join(filenames),
            "Label type": label_types[idx],
        }

        # merge retriever and reader metrics with the experiment description into one flat CSV row
        if 'JoinAnswers' in res_dict:
            csv_dict_new = {**res_dict['EmbeddingRetriever'], **res_dict['JoinAnswers'], **exp_dict}
        elif 'TableReader' in res_dict:
            csv_dict_new = {**res_dict['EmbeddingRetriever'], **res_dict['TableReader'], **exp_dict}
        elif 'TextReader' in res_dict:
            csv_dict_new = {**res_dict['EmbeddingRetriever'], **res_dict['TextReader'], **exp_dict}

        if idx == 1:
            # combine the metrics of both label sets into a single "all_eval" row,
            # weighting each metric by its dataset's number of evaluation examples
            csv_dict_all = {}
            total_num_samples = csv_dict["num_examples_for_eval"] + csv_dict_new["num_examples_for_eval"]
            weight_old = csv_dict["num_examples_for_eval"]
            weight_new = csv_dict_new["num_examples_for_eval"]
            print("Weights for datasets:", weight_old, weight_new)
            for key, val in csv_dict.items():
                if not isinstance(val, str):
                    if key != "num_examples_for_eval":
                        csv_dict_all[key] = (val * weight_old + csv_dict_new[key] * weight_new) / total_num_samples
                    else:
                        csv_dict_all[key] = val + csv_dict_new[key]
                else:
                    csv_dict_all[key] = val
            csv_dict_all["Label type"] = "all_eval"
            with open("./output/results.csv", "a", newline='') as f:
                writer = csv.DictWriter(f, fieldnames=header)
                writer.writerow(csv_dict_all)

        # remember this dataset's row for the combined computation and append it to the results file
        csv_dict = copy.deepcopy(csv_dict_new)
        print(csv_dict)
        with open("./output/results.csv", "a", newline='') as f:
            writer = csv.DictWriter(f, fieldnames=header)
            writer.writerow(csv_dict)

    # clean up the temporary index
    document_store.delete_index(document_index)


if __name__ == "__main__":
    main()