Spaces:
Runtime error
Runtime error
gorkaartola
commited on
Commit
•
0214f0b
1
Parent(s):
11fc8af
Update metric_for_tp_fp_samples.py
Browse files
metric_for_tp_fp_samples.py
CHANGED
@@ -191,7 +191,7 @@ class metric_tp_fp_Datasets(evaluate.Metric):
|
|
191 |
return results
|
192 |
|
193 |
#Computes the metric for each prediction strategy##############################################
|
194 |
-
def _compute(self, predictions, references, prediction_strategies = [["argmax_max"],]):
|
195 |
"""Returns the scores"""
|
196 |
# TODO: Compute the different scores of the metric
|
197 |
predictions = torch.from_numpy(np.array(predictions, dtype = 'float32'))
|
@@ -223,6 +223,9 @@ class metric_tp_fp_Datasets(evaluate.Metric):
|
|
223 |
if j[0] in i:
|
224 |
TP_data.loc[TP_data["class"] == j[0], "coincidence count"] += 1
|
225 |
TP_data = TP_data.sort_values(by=["class"], ignore_index = True)
|
|
|
|
|
|
|
226 |
if j[1] == 2:
|
227 |
FP_data.loc[FP_data["class"] == j[0], "number of samples"] += 1
|
228 |
if len(i) >> 0:
|
|
|
191 |
return results
|
192 |
|
193 |
#Computes the metric for each prediction strategy##############################################
|
194 |
+
def _compute(self, predictions, references, prediction_strategies = [["argmax_max"],], FPifWrong = False):
|
195 |
"""Returns the scores"""
|
196 |
# TODO: Compute the different scores of the metric
|
197 |
predictions = torch.from_numpy(np.array(predictions, dtype = 'float32'))
|
|
|
223 |
if j[0] in i:
|
224 |
TP_data.loc[TP_data["class"] == j[0], "coincidence count"] += 1
|
225 |
TP_data = TP_data.sort_values(by=["class"], ignore_index = True)
|
226 |
+
elif FPifWrong:
|
227 |
+
FP_data.loc[TP_data["class"] == j[0], "coincidence count"] += 1
|
228 |
+
FP_data = TP_data.sort_values(by=["class"], ignore_index = True)
|
229 |
if j[1] == 2:
|
230 |
FP_data.loc[FP_data["class"] == j[0], "number of samples"] += 1
|
231 |
if len(i) >> 0:
|