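"""Evaluate retriever + reader QA pipelines over text and table contexts.

Loads the selected context files into a document store, builds a retriever and
the chosen text/table readers, evaluates the pipeline against labelled QA
data, and appends the resulting metrics to ./output/results.csv.
"""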
# imports
import argparse
import copy
import csv
import logging
import pickle

from setup_database import get_document_store, add_data
from setup_modules import create_retriever, create_readers_and_pipeline, text_reader_types, table_reader_types
from eval_helper import create_labels

def parse_args():
   parser = argparse.ArgumentParser(description="JointXplore")

   parser.add_argument("--context", help="which information should be added as context; subset of [processed_website_tables, processed_website_text, processed_schedule_tables], given as multiple strings",
                     nargs='+', default=["processed_website_tables", "processed_website_text", "processed_schedule_tables"])
   parser.add_argument("--text_reader", help="specify the model to use as text reader", choices=["minilm", "distilroberta", "electra-base", "bert-base", "deberta-large", "gpt3"], default="bert-base")
   # --api-key carries the key string itself, so it takes a value rather than
   # acting as a boolean flag.
   parser.add_argument("--api-key", help="if gpt3 is chosen as reader, please provide an api-key", type=str, default=None)
   parser.add_argument("--table_reader", help="choose tapas, or convert tables to text and treat them as such", choices=["tapas", "text"], default="tapas")
   parser.add_argument("--seperate_evaluation", help="if specified, student-generated and synthetically generated questions are evaluated separately", action="store_true")
   return parser.parse_args()
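# Example invocation (the script filename is a placeholder; flags are defined
# by the parser above):
#   python evaluate.py --context processed_website_text processed_schedule_tables \
#      --text_reader bert-base --table_reader tapas --seperate_evaluation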
 
def main(*args):
   # When called with no positional arguments, fall back to the CLI parser
   # (note: *args is always a tuple, so it is checked for emptiness, never
   # compared against {}).
   if not args:
      cli_args = parse_args()
      filenames = cli_args.context
      text_reader = cli_args.text_reader
      table_reader = cli_args.table_reader
      seperate_evaluation = cli_args.seperate_evaluation
   else:
      filenames, text_reader, table_reader, seperate_evaluation = args
   print(f"Filenames: {filenames}")
   use_table = "processed_schedule_tables" in filenames
   # Text documents are used if either website text or website tables are
   # requested; both names must be tested for membership explicitly, since a
   # bare string literal is always truthy.
   use_text = "processed_website_text" in filenames or "processed_website_tables" in filenames
   
   # configure logger
   logging.basicConfig(format="%(levelname)s - %(name)s -  %(message)s", level=logging.WARNING)
   logging.getLogger("haystack").setLevel(logging.WARNING)
   logging.info("Starting...")
   document_index = "document"
   document_store = get_document_store(document_index)
   print(f"Number of docs previously: {len(document_store.get_all_documents())}")
   document_store, data = add_data(filenames, document_store, document_index)
   print(f"Number of docs after: {len(document_store.get_all_documents())}")
   document_store, retriever = create_retriever(document_store)
   text_reader_type = text_reader_types[text_reader]
   table_reader_type = table_reader_types[table_reader]
   pipeline = create_readers_and_pipeline(retriever, text_reader_type, table_reader_type, use_table, use_text)
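   # With both use_table and use_text set, the pipeline combines text and
   # table readers (the 'JoinAnswers' node checked below); otherwise it
   # contains a single reader node.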
   
   # Read the header row of the existing results file so appended rows keep
   # the same column order.
   with open("./output/results.csv", "r") as f:
      header = next(csv.reader(f))

   labels_file = "./data/validation_data/processed_qa.json"
   labels = create_labels(labels_file, data, seperate_evaluation)
   label_types = ["all_eval"]
   if seperate_evaluation:
      label_types = ["students", "synthetic"]
   for idx, label in enumerate(labels):
      print(f"Label Dataset: {idx}")
      results = pipeline.eval(label, params={"top_k": 10}, sas_model_name_or_path="cross-encoder/stsb-roberta-large")
      res_dict = results.calculate_metrics()
      print(res_dict)
      with open(f"./output/{text_reader}_{table_reader}_{('_').join(filenames)}_{label_types[idx]}", "wb") as fp:
         pickle.dump(results, fp)
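      # The pickled results object can be reloaded later (e.g. via
      # pickle.load) to recompute metrics without re-running the pipeline.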
      exp_dict = {
         "Text Reader": text_reader,
         "Table Reader": table_reader,
         "Context": "_".join(filenames),
         "Label type": label_types[idx]
      }
      # Combine the retriever metrics with those of whichever reader node the
      # pipeline contains; if neither reader reported, keep the retriever
      # metrics alone so csv_dict_new is always defined.
      if 'JoinAnswers' in res_dict:
         csv_dict_new = {**res_dict['EmbeddingRetriever'], **res_dict['JoinAnswers'], **exp_dict}
      elif 'TableReader' in res_dict:
         csv_dict_new = {**res_dict['EmbeddingRetriever'], **res_dict['TableReader'], **exp_dict}
      elif 'TextReader' in res_dict:
         csv_dict_new = {**res_dict['EmbeddingRetriever'], **res_dict['TextReader'], **exp_dict}
      else:
         csv_dict_new = {**res_dict['EmbeddingRetriever'], **exp_dict}
      # On the second label set (idx == 1, only reached with separate
      # evaluation) additionally write a combined "all_eval" row: numeric
      # metrics are averaged, weighted by each set's number of evaluated
      # examples.
      if idx == 1:
         csv_dict_all = {}
         weight_old = csv_dict["num_examples_for_eval"]
         weight_new = csv_dict_new["num_examples_for_eval"]
         total_num_samples = weight_old + weight_new
         print("Weights for datasets:", weight_old, weight_new)
         for key, val in csv_dict.items():
            if not isinstance(val, str):
               if key != "num_examples_for_eval":
                  # weighted mean: (m_old*n_old + m_new*n_new) / (n_old + n_new)
                  csv_dict_all[key] = (val * weight_old + csv_dict_new[key] * weight_new) / total_num_samples
               else:
                  csv_dict_all[key] = val + csv_dict_new[key]
            else:
               csv_dict_all[key] = val
         csv_dict_all["Label type"] = "all_eval"
         with open("./output/results.csv", "a", newline='') as f:
            writer = csv.DictWriter(f, fieldnames=header)
            writer.writerow(csv_dict_all)
      csv_dict = copy.deepcopy(csv_dict_new)
      print(csv_dict)
      
      with open("./output/results.csv", "a", newline='') as f:
         writer = csv.DictWriter(f, fieldnames=header)
         writer.writerow(csv_dict)
   
   document_store.delete_index(document_index)


if __name__ == "__main__":
   main()