gorkaartola commited on
Commit
0214f0b
1 Parent(s): 11fc8af

Update metric_for_tp_fp_samples.py

Browse files
Files changed (1) hide show
  1. metric_for_tp_fp_samples.py +4 -1
metric_for_tp_fp_samples.py CHANGED
@@ -191,7 +191,7 @@ class metric_tp_fp_Datasets(evaluate.Metric):
191
  return results
192
 
193
  #Computes the metric for each prediction strategy##############################################
194
- def _compute(self, predictions, references, prediction_strategies = [["argmax_max"],]):
195
  """Returns the scores"""
196
  # TODO: Compute the different scores of the metric
197
  predictions = torch.from_numpy(np.array(predictions, dtype = 'float32'))
@@ -223,6 +223,9 @@ class metric_tp_fp_Datasets(evaluate.Metric):
223
  if j[0] in i:
224
  TP_data.loc[TP_data["class"] == j[0], "coincidence count"] += 1
225
  TP_data = TP_data.sort_values(by=["class"], ignore_index = True)
 
 
 
226
  if j[1] == 2:
227
  FP_data.loc[FP_data["class"] == j[0], "number of samples"] += 1
228
  if len(i) >> 0:
 
191
  return results
192
 
193
  #Computes the metric for each prediction strategy##############################################
194
+ def _compute(self, predictions, references, prediction_strategies = [["argmax_max"],], FPifWrong = False):
195
  """Returns the scores"""
196
  # TODO: Compute the different scores of the metric
197
  predictions = torch.from_numpy(np.array(predictions, dtype = 'float32'))
 
223
  if j[0] in i:
224
  TP_data.loc[TP_data["class"] == j[0], "coincidence count"] += 1
225
  TP_data = TP_data.sort_values(by=["class"], ignore_index = True)
226
+ elif FPifWrong:
227
+ FP_data.loc[TP_data["class"] == j[0], "coincidence count"] += 1
228
+ FP_data = TP_data.sort_values(by=["class"], ignore_index = True)
229
  if j[1] == 2:
230
  FP_data.loc[FP_data["class"] == j[0], "number of samples"] += 1
231
  if len(i) >> 0: