Added aggregate function to the SemF1 metric; added a test to check the functionality.
semf1.py
CHANGED
@@ -339,7 +339,7 @@ class SemF1(evaluate.Metric):
             gpu: DEVICE_TYPE = False,
             batch_size: int = 32,
             verbose: bool = False,
-            aggregate: bool =
+            aggregate: bool = False,
     ) -> List[Scores]:
         """
         Compute precision, recall, and F1 scores for given predictions and references.
@@ -421,7 +421,22 @@ class SemF1(evaluate.Metric):
             recall_scores = [np.clip(r_scores, 0.0, 1.0).item() for (r_scores, _) in recall_scores]

             results.append(Scores(precision, recall_scores))
-
-
-
-
+
+        # run aggregation procedure
+        if aggregate:
+            mean_prec = np.mean(
+                [score.precision for score in results]
+            )
+            mean_recall = np.mean(np.concatenate(
+                [np.array(score.recall) for score in results]
+            ))
+            aggregated_score = Scores(
+                float(mean_prec),
+                [float(mean_recall)]
+            )
+            aggregated_score.f1 = float(np.mean(
+                [score.f1 for score in results]
+            ))
+            results = aggregated_score
+
+        return results
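The aggregation above is a plain mean: with `aggregate=True`, the per-sample `Scores` are collapsed into a single `Scores` whose precision and f1 are averaged across samples, and whose recall is averaged after concatenating every per-reference recall list. A minimal usage sketch, assuming the metric can be instantiated directly as the tests' `semf1_metric` fixture implies, and using made-up sample texts:

from semf1 import SemF1  # inside the package this is the relative import `.semf1`

metric = SemF1()
preds = ["The first prediction.", "The second prediction."]
refs = ["The first reference.", "The second reference."]

# Default: one Scores object per prediction/reference pair.
per_sample = metric.compute(predictions=preds, references=refs)

# aggregate=True: a single Scores object with mean precision, mean
# (flattened) recall, and mean f1 over all samples.
overall = metric.compute(predictions=preds, references=refs, aggregate=True)
print(overall.precision, overall.recall, overall.f1)

Note that after aggregation `compute` no longer returns the annotated `List[Scores]` but a bare `Scores`, which is exactly what the new test below asserts.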
tests.py
CHANGED
@@ -6,6 +6,7 @@ import torch
 from numpy.testing import assert_almost_equal
 from sentence_transformers import SentenceTransformer
 from sklearn.metrics.pairwise import cosine_similarity
+from unittest import TestLoader

 from .encoder_models import SBertEncoder, get_encoder
 from .semf1 import SemF1, _compute_cosine_similarity, _validate_input_format
@@ -13,6 +14,14 @@ from .utils import get_gpu, slice_embeddings, is_nested_list_of_type, flatten_list


 class TestUtils(unittest.TestCase):
+    def runTest(self):
+        self.test_get_gpu()
+        self.test_slice_embeddings()
+        self.test_is_nested_list_of_type()
+        self.test_flatten_list()
+        self.test_compute_f1()
+        self.test_scores()
+
     def test_get_gpu(self):
         gpu_count = torch.cuda.device_count()
         gpu_available = torch.cuda.is_available()
@@ -231,6 +240,32 @@ class TestSemF1(unittest.TestCase):
                 ["Alternative reference 1.", "Alternative reference 2."]
             ],
         ]
+        self.multi_sample_refs = [
+            'this is the first reference sample',
+            'this is the second reference sample',
+        ]
+        self.multi_sample_preds = [
+            'this is the first prediction sample',
+            'this is the second prediction sample',
+        ]
+
+    def test_aggregate_flag(self):
+        """
+        check if a `Scores` class is returned instead of a list of
+        `Scores`
+        """
+        scores = self.semf1_metric.compute(
+            predictions=self.multi_sample_preds,
+            references=self.multi_sample_refs,
+            tokenize_sentences=True,
+            multi_references=False,
+            gpu=False,
+            batch_size=32,
+            verbose=False,
+            aggregate=True,
+        )
+        self.assertIsInstance(scores, Scores)
+

     def test_untokenized_single_reference(self):
         scores = self.semf1_metric.compute(
@@ -600,5 +635,8 @@ class TestValidateInputFormat(unittest.TestCase):
         )


-if __name__ == '__main__':
+def run_tests():
     unittest.main(verbosity=2)
+
+if __name__ == '__main__':
+    run_tests()
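With `run_tests()` as the module-level entry point, the whole suite still runs via `python -m <package>.tests`, and the newly imported `TestLoader` makes it easy to run a single case in isolation. A sketch of running just the new aggregate test; the dotted path is an assumption about where tests.py sits in the package and should be adjusted to the actual layout:

import unittest

# Load only the new aggregate-flag test; "tests" here is a hypothetical
# module path for tests.py.
suite = unittest.TestLoader().loadTestsFromName("tests.TestSemF1.test_aggregate_flag")
unittest.TextTestRunner(verbosity=2).run(suite)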