lvwerra HF staff commited on
Commit
b2436ac
1 Parent(s): 01ab7ce

Update Space (evaluate main: e4a27243)

Browse files
Files changed (2) hide show
  1. requirements.txt +1 -1
  2. squad_v2.py +23 -4
requirements.txt CHANGED
@@ -1 +1 @@
1
- git+https://github.com/huggingface/evaluate@80448674f5447a9682afe051db243c4a13bfe4ff
 
1
+ git+https://github.com/huggingface/evaluate@e4a2724377909fe2aeb4357e3971e5a569673b39
squad_v2.py CHANGED
@@ -13,6 +13,8 @@
13
  # limitations under the License.
14
  """ SQuAD v2 metric. """
15
 
 
 
16
  import datasets
17
 
18
  import evaluate
@@ -87,13 +89,26 @@ Examples:
87
  """
88
 
89
 
 
 
 
 
 
 
 
 
90
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
91
  class SquadV2(evaluate.Metric):
92
- def _info(self):
 
 
 
 
93
  return evaluate.MetricInfo(
94
  description=_DESCRIPTION,
95
  citation=_CITATION,
96
  inputs_description=_KWARGS_DESCRIPTION,
 
97
  features=datasets.Features(
98
  {
99
  "predictions": {
@@ -113,7 +128,7 @@ class SquadV2(evaluate.Metric):
113
  reference_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
114
  )
115
 
116
- def _compute(self, predictions, references, no_answer_threshold=1.0):
117
  no_answer_probabilities = {p["id"]: p["no_answer_probability"] for p in predictions}
118
  dataset = [{"paragraphs": [{"qas": references}]}]
119
  predictions = {p["id"]: p["prediction_text"] for p in predictions}
@@ -123,8 +138,12 @@ class SquadV2(evaluate.Metric):
123
  no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
124
 
125
  exact_raw, f1_raw = get_raw_scores(dataset, predictions)
126
- exact_thresh = apply_no_ans_threshold(exact_raw, no_answer_probabilities, qid_to_has_ans, no_answer_threshold)
127
- f1_thresh = apply_no_ans_threshold(f1_raw, no_answer_probabilities, qid_to_has_ans, no_answer_threshold)
 
 
 
 
128
  out_eval = make_eval_dict(exact_thresh, f1_thresh)
129
 
130
  if has_ans_qids:
 
13
  # limitations under the License.
14
  """ SQuAD v2 metric. """
15
 
16
+ from dataclasses import dataclass
17
+
18
  import datasets
19
 
20
  import evaluate
 
89
  """
90
 
91
 
92
+ @dataclass
93
+ class SquadV2Config(evaluate.info.Config):
94
+
95
+ name: str = "default"
96
+
97
+ no_answer_threshold: float = 1.0
98
+
99
+
100
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
101
  class SquadV2(evaluate.Metric):
102
+
103
+ CONFIG_CLASS = SquadV2Config
104
+ ALLOWED_CONFIG_NAMES = ["default"]
105
+
106
+ def _info(self, config):
107
  return evaluate.MetricInfo(
108
  description=_DESCRIPTION,
109
  citation=_CITATION,
110
  inputs_description=_KWARGS_DESCRIPTION,
111
+ config=config,
112
  features=datasets.Features(
113
  {
114
  "predictions": {
 
128
  reference_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
129
  )
130
 
131
+ def _compute(self, predictions, references):
132
  no_answer_probabilities = {p["id"]: p["no_answer_probability"] for p in predictions}
133
  dataset = [{"paragraphs": [{"qas": references}]}]
134
  predictions = {p["id"]: p["prediction_text"] for p in predictions}
 
138
  no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
139
 
140
  exact_raw, f1_raw = get_raw_scores(dataset, predictions)
141
+ exact_thresh = apply_no_ans_threshold(
142
+ exact_raw, no_answer_probabilities, qid_to_has_ans, self.config.no_answer_threshold
143
+ )
144
+ f1_thresh = apply_no_ans_threshold(
145
+ f1_raw, no_answer_probabilities, qid_to_has_ans, self.config.no_answer_threshold
146
+ )
147
  out_eval = make_eval_dict(exact_thresh, f1_thresh)
148
 
149
  if has_ans_qids: