GGroenendaal committed on
Commit
aa426fb
1 Parent(s): 90fe7fe

minor renaming and cleanup

Browse files
Files changed (1) hide show
  1. base_model/retriever.py +16 -13
base_model/retriever.py CHANGED
@@ -22,7 +22,7 @@ class Retriever:
22
  based on https://huggingface.co/docs/datasets/faiss_es#faiss.
23
  """
24
 
25
- def __init__(self, dataset: str = "GroNLP/ik-nlp-22_slp") -> None:
26
  """Initialize the retriever
27
 
28
  Args:
@@ -49,12 +49,12 @@ class Retriever:
49
  )
50
 
51
  # Dataset building
52
- self.dataset = self.__init_dataset(dataset)
 
53
 
54
-
55
- def __init_dataset(self,
56
- dataset: str,
57
- fname: str = "./models/paragraphs_embedding.faiss"):
58
  """Loads the dataset and adds FAISS embeddings.
59
 
60
  Args:
@@ -67,12 +67,12 @@ class Retriever:
67
  embeddings.
68
  """
69
  # Load dataset
70
- ds = load_dataset(dataset, name="paragraphs")["train"]
71
  print(ds)
72
 
73
- if os.path.exists(fname):
74
  # If we already have FAISS embeddings, load them from disk
75
- ds.load_faiss_index('embeddings', fname)
76
  return ds
77
  else:
78
  # If there are no FAISS embeddings, generate them
@@ -91,7 +91,7 @@ class Retriever:
91
 
92
  # save dataset w/ embeddings
93
  os.makedirs("./models/", exist_ok=True)
94
- ds_with_embeddings.save_faiss_index("embeddings", fname)
95
 
96
  return ds_with_embeddings
97
 
@@ -127,7 +127,8 @@ class Retriever:
127
  float: overall exact match
128
  float: overall F1-score
129
  """
130
- questions_ds = load_dataset("GroNLP/ik-nlp-22_slp", name="questions")['test']
 
131
  questions = questions_ds['question']
132
  answers = questions_ds['answer']
133
 
@@ -140,7 +141,9 @@ class Retriever:
140
  scores += score[0]
141
  predictions.append(result['text'][0])
142
 
143
- exact_matches = [evaluate.compute_exact_match(predictions[i], answers[i]) for i in range(len(answers))]
144
- f1_scores = [evaluate.compute_f1(predictions[i], answers[i]) for i in range(len(answers))]
 
 
145
 
146
  return sum(exact_matches) / len(exact_matches), sum(f1_scores) / len(f1_scores)
 
22
  based on https://huggingface.co/docs/datasets/faiss_es#faiss.
23
  """
24
 
25
+ def __init__(self, dataset_name: str = "GroNLP/ik-nlp-22_slp") -> None:
26
  """Initialize the retriever
27
 
28
  Args:
 
49
  )
50
 
51
  # Dataset building
52
+ self.dataset_name = dataset_name
53
+ self.dataset = self._init_dataset(dataset_name)
54
 
55
+ def _init_dataset(self,
56
+ dataset_name: str,
57
+ embedding_path: str = "./models/paragraphs_embedding.faiss"):
 
58
  """Loads the dataset and adds FAISS embeddings.
59
 
60
  Args:
 
67
  embeddings.
68
  """
69
  # Load dataset
70
+ ds = load_dataset(dataset_name, name="paragraphs")["train"]
71
  print(ds)
72
 
73
+ if os.path.exists(embedding_path):
74
  # If we already have FAISS embeddings, load them from disk
75
+ ds.load_faiss_index('embeddings', embedding_path)
76
  return ds
77
  else:
78
  # If there are no FAISS embeddings, generate them
 
91
 
92
  # save dataset w/ embeddings
93
  os.makedirs("./models/", exist_ok=True)
94
+ ds_with_embeddings.save_faiss_index("embeddings", embedding_path)
95
 
96
  return ds_with_embeddings
97
 
 
127
  float: overall exact match
128
  float: overall F1-score
129
  """
130
+ questions_ds = load_dataset(
131
+ self.dataset_name, name="questions")['test']
132
  questions = questions_ds['question']
133
  answers = questions_ds['answer']
134
 
 
141
  scores += score[0]
142
  predictions.append(result['text'][0])
143
 
144
+ exact_matches = [evaluate.compute_exact_match(
145
+ predictions[i], answers[i]) for i in range(len(answers))]
146
+ f1_scores = [evaluate.compute_f1(
147
+ predictions[i], answers[i]) for i in range(len(answers))]
148
 
149
  return sum(exact_matches) / len(exact_matches), sum(f1_scores) / len(f1_scores)