nbansal committed
Commit
42c888f
1 Parent(s): 2c33aa3

Updated the documentation and added more test cases.

Files changed (3)
  1. README.md +35 -19
  2. semf1.py +90 -55
  3. tests.py +320 -1
README.md CHANGED
@@ -25,49 +25,65 @@ summary with the reference overlap summary. It evaluates the semantic overlap su
25
  computes precision, recall and F1 scores.
26
 
27
  ## How to Use
28
- Sem-F1 takes 2 mandatory arguments:
29
- `predictions`: (a list of system generated documents in the form of sentences i.e. List[List[str]]),
30
- `references`: (a list of ground-truth documents in the form of sentences i.e. List[List[str]])
 
31
 
32
  ```python
33
  from evaluate import load
 
34
  predictions = [
35
  ["I go to School.", "You are stupid."],
36
  ["I love adventure sports."],
37
  ]
38
  references = [
39
  ["I go to School.", "You are stupid."],
40
- ["I love adventure sports."],
41
  ]
42
  metric = load("semf1")
43
  results = metric.compute(predictions=predictions, references=references)
 
 
44
  ```
45
 
46
- It also accepts another optional arguments:
47
-
48
- `model_type: Optional[str]`:
49
- The model to use for encoding the sentences.
50
- Options are:
51
- [`pv1`](https://huggingface.co/sentence-transformers/paraphrase-distilroberta-base-v1),
52
- [`stsb`](https://huggingface.co/sentence-transformers/stsb-roberta-large),
53
- [`use`](https://huggingface.co/sentence-transformers/use-cmlm-multilingual).
54
- The default value is `use`.
55
 
56
- [//]: # (### Inputs)
57
 
58
  [//]: # (*List all input arguments in the format below*)
59
 
60
  [//]: # (- **input_field** *(type): Definition of input, with explanation if necessary. State any default value(s).*)
61
 
62
  ### Output Values
 
 
 
 
63
 
64
- `precision`: The [precision](https://huggingface.co/metrics/precision) for each sentence from the `predictions` + `references` lists, which ranges from 0.0 to 1.0.
65
-
66
- `recall`: The [recall](https://huggingface.co/metrics/recall) for each sentence from the `predictions` + `references` lists, which ranges from 0.0 to 1.0.
67
 
68
- `f1`: The [F1 score](https://huggingface.co/metrics/f1) for each sentence from the `predictions` + `references` lists, which ranges from 0.0 to 1.0.
 
 
 
69
 
70
- [//]: # (#### Values from Popular Papers)
 
71
 
72
  [//]: # (*Give examples, preferrably with links to leaderboards or publications, to papers that have reported this metric, along with the values they have reported.*)
73
 
 
25
  computes precision, recall and F1 scores.
26
 
27
  ## How to Use
28
+
29
+ Sem-F1 takes 2 mandatory arguments:
30
+ - `predictions`: List of predictions. Format varies based on `tokenize_sentences` and `multi_references` flags.
31
+ - `references`: List of references. Format varies based on `tokenize_sentences` and `multi_references` flags.
32
 
33
  ```python
34
  from evaluate import load
35
+
36
  predictions = [
37
  ["I go to School.", "You are stupid."],
38
  ["I love adventure sports."],
39
  ]
40
  references = [
41
  ["I go to School.", "You are stupid."],
42
+ ["I love outdoor sports."],
43
  ]
44
  metric = load("semf1")
45
  results = metric.compute(predictions=predictions, references=references)
46
+ for score in results:
47
+     print(f"Precision: {score.precision}, Recall: {score.recall}, F1: {score.f1}")
48
  ```
49
 
50
+ Sem-F1 also accepts multiple optional arguments:
51
+ - `model_type (str)`: Model to use for encoding sentences. Options: ['pv1', 'stsb', 'use']
52
+ - `pv1` - [paraphrase-distilroberta-base-v1](https://huggingface.co/sentence-transformers/paraphrase-distilroberta-base-v1)
53
+ - `stsb` - [stsb-roberta-large](https://huggingface.co/sentence-transformers/stsb-roberta-large)
54
+ - `use` - [Universal Sentence Encoder](https://huggingface.co/sentence-transformers/use-cmlm-multilingual) (Default)
55
+ - `tokenize_sentences (bool)`: Flag to indicate whether to tokenize the sentences in the input documents. Default: True.
56
+ - `multi_references (bool)`: Flag to indicate whether multiple references are provided. Default: False.
57
+ - `gpu (Union[bool, str, int, List[Union[str, int]]])`: Whether to use GPU, CPU, or multiple processes for computation.
58
+ - `batch_size (int)`: Batch size for encoding. Default: 32.
59
+ - `verbose (bool)`: Flag to indicate verbose output. Default: False.
60
+
61
+ Refer to the inputs description for more detailed usage, as follows:
62
+ ```python
63
+ import evaluate
64
+ metric = evaluate.load("semf1")
65
+ metric.inputs_description
66
+ ```
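For instance, a usage sketch combining the optional arguments above (the flag values and the `stsb` model choice here are illustrative, not part of this commit) could look like:

```python
import evaluate

metric = evaluate.load("semf1")

# Two references per prediction; inputs are raw documents, so tokenize_sentences stays True.
predictions = ["I go to School. You are stupid."]
references = [
    ["I go to School. You are stupid.", "I attend school. You are not smart."],
]

results = metric.compute(
    predictions=predictions,
    references=references,
    model_type="stsb",
    tokenize_sentences=True,
    multi_references=True,
    gpu=False,
    batch_size=32,
    verbose=False,
)
for score in results:
    print(f"Precision: {score.precision}, Recall: {score.recall}, F1: {score.f1}")
```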
67
 
 
68
 
69
  [//]: # (*List all input arguments in the format below*)
70
 
71
  [//]: # (- **input_field** *(type): Definition of input, with explanation if necessary. State any default value(s).*)
72
 
73
  ### Output Values
74
+ A list of `Scores` dataclass instances, one per input sample:
75
+ - `precision: float`: Precision score, which ranges from 0.0 to 1.0.
76
+ - `recall: List[float]`: Recall score corresponding to each reference
77
+ - `f1: float`: F1 score (between precision and average recall).
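In other words, `f1` is roughly the harmonic mean of precision and the average recall. A small illustrative sketch (the actual implementation is `compute_f1` in `utils.py`, which is not shown in this diff, so treat this as an approximation):

```python
from typing import List

def approx_f1(precision: float, recalls: List[float]) -> float:
    # Illustrative reconstruction: harmonic mean of precision and the mean recall.
    mean_recall = sum(recalls) / len(recalls)
    if precision + mean_recall == 0:
        return 0.0
    return 2 * precision * mean_recall / (precision + mean_recall)
```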
78
 
 
 
 
79
 
80
+ ## Future Extensions
81
+ Currently, we have only implemented the 3 encoders* that we experimented with in our
82
+ [paper](https://aclanthology.org/2022.emnlp-main.49/). However, it can easily be extended to more models by simply
83
+ extending the `Encoder` base class (refer to the `encoder_models.py` file); a minimal sketch follows below.
84
 
85
+ `*` *In our paper, we used the TensorFlow [version](https://www.tensorflow.org/hub/tutorials/semantic_similarity_with_tf_hub_universal_encoder)
86
+ of the USE model; however, in the current implementation we use the [PyTorch version](https://huggingface.co/sentence-transformers/use-cmlm-multilingual).*
87
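As an illustration, a sketch of such an extension might look like the following. It assumes the `Encoder` base class in `encoder_models.py` can be subclassed by overriding an `encode` method; that interface, the model name, and the constructor arguments are assumptions made for this example, not part of this commit.

```python
from typing import List

import numpy as np
from sentence_transformers import SentenceTransformer

from encoder_models import Encoder  # assumed import path and base-class name


class MpnetEncoder(Encoder):  # hypothetical subclass for illustration only
    def __init__(self, device="cpu", batch_size: int = 32, verbose: bool = False):
        self.model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2", device=device)
        self.batch_size = batch_size
        self.verbose = verbose

    def encode(self, sentences: List[str]) -> np.ndarray:
        # One embedding per input sentence, as the metric expects.
        return self.model.encode(sentences, batch_size=self.batch_size, show_progress_bar=self.verbose)
```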
 
88
  [//]: # (*Give examples, preferrably with links to leaderboards or publications, to papers that have reported this metric, along with the values they have reported.*)
89
 
semf1.py CHANGED
@@ -14,7 +14,6 @@
14
  # TODO: Add test cases, Remove tokenize_sentences flag since it can be determined from the input itself.
15
  """Sem-F1 metric"""
16
 
17
- from functools import partial
18
  from typing import List, Optional, Tuple
19
 
20
  import datasets
@@ -56,69 +55,93 @@ sentence level and computes precision, recall and F1 scores.
56
  """
57
 
58
  _KWARGS_DESCRIPTION = """
59
- Sem-F1 compares the system generated overlap summary with ground truth reference overlap.
 
60
 
61
  Args:
62
- predictions: list - List of predictions (Details below)
63
- references: list - List of references (Details below)
64
- reference should be a string with tokens separated by spaces.
65
- model_type: str - Model to use. [pv1, stsb, use]
66
- Options:
67
- pv1 - paraphrase-distilroberta-base-v1 (Default)
68
- stsb - stsb-roberta-large
69
- use - Universal Sentence Encoder
70
- tokenize_sentences: bool - Sentence tokenize the input document (prediction/reference). Default: True.
71
- gpu: Union[bool, int] - Whether to use GPU or CPU.
72
- Options:
73
  False - CPU (Default)
74
- True - GPU, device 0
75
- n: int - GPU, device n
76
- batch_size: int - Batch Size, Default = 32.
77
  Returns:
78
- precision: Precision.
79
- recall: Recall.
80
- f1: F1 score.
81
-
82
- There are 4 possible cases for inputs corresponding to predictions and references arguments
83
- Case 1: Multi-Ref = False, tokenize_sentences = False
84
- predictions: List[List[str]] - List of predictions where each prediction is a list of sentences.
 
 
85
  references: List[List[str]] - List of references where each reference is a list of sentences.
86
- Case 2: Multi-Ref = False, tokenize_sentences = True
87
- predictions: List[str] - List of predictions where each prediction is a document
88
- references: List[str] - List of references where each reference is a document
89
- Case 3: Multi-Ref = True, tokenize_sentences = False
90
- predictions: List[List[str]] - List of predictions where each prediction is a list of sentences.
91
- references: List[List[List[str]]] - List of multi-references i.e. [[r11, r12, ...], [r21, r22, ...], ...]
92
- where each rij is further a list of sentences
93
- Case 4: Multi-Ref = True, tokenize_sentences = True
94
- predictions: List[str] - List of predictions where each prediction is a document
95
- references: List[List[str]] - List of multi-references i.e. [[r11, r12, ...], [r21, r22, ...], ...]
96
- where each rij is a document
97
-
98
- This can be seen in the form of truth table as follows:
99
- Case | Multi-Ref | tokenize_sentences | predictions | references
100
- 1 | 0 | 0 | List[List[str]] | List[List[str]]
101
- 2 | 0 | 1 | List[str] | List[str]
102
- 3 | 1 | 0 | List[List[str]] | List[List[List[str]]]
103
- 4 | 1 | 1 | List[str] | List[List[str]]
104
-
105
- It is automatically determined whether it is Multi-Ref case Single-Ref case.
106
-
107
  Examples:
108
 
109
  >>> import evaluate
110
  >>> predictions = [
111
- ["I go to School.", "You are stupid."],
112
  ["I love adventure sports."],
113
  ]
114
  >>> references = [
115
- ["I go to School.", "You are stupid."],
116
- ["I love adventure sports."],
117
  ]
118
  >>> metric = evaluate.load("semf1")
119
  >>> results = metric.compute(predictions=predictions, references=references)
120
- >>> print([round(v, 2) for v in results["f1"]])
121
- [0.77, 0.56]
122
  """
123
 
124
 
@@ -194,7 +217,12 @@ def _validate_input_format(
194
  - `PREDICTION_TYPE` and `REFERENCE_TYPE` are defined at the top of the file
195
  """
196
 
197
- is_list_of_strings_at_depth = partial(is_nested_list_of_type, element_type=str)
198
  if tokenize_sentences and multi_references:
199
  condition = is_list_of_strings_at_depth(predictions, 1) and is_list_of_strings_at_depth(references, 2)
200
  elif not tokenize_sentences and multi_references:
@@ -225,7 +253,7 @@ class SemF1(evaluate.Metric):
225
  inputs_description=_KWARGS_DESCRIPTION,
226
  # This defines the format of each prediction and reference
227
  features=[
228
- # Multi References: False, Tokenize_Sentences = False
229
  datasets.Features(
230
  {
231
  # predictions: List[List[str]] - List of predictions where prediction is a list of sentences
@@ -234,7 +262,7 @@ class SemF1(evaluate.Metric):
234
  "references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"),
235
  }
236
  ),
237
- # Multi References: False, Tokenize_Sentences = True
238
  datasets.Features(
239
  {
240
  # predictions: List[str] - List of predictions
@@ -243,7 +271,7 @@ class SemF1(evaluate.Metric):
243
  "references": datasets.Value("string", id="sequence"),
244
  }
245
  ),
246
- # Multi References: True, Tokenize_Sentences = False
247
  datasets.Features(
248
  {
249
  # predictions: List[List[str]] - List of predictions where prediction is a list of sentences
@@ -255,7 +283,7 @@ class SemF1(evaluate.Metric):
255
  datasets.Sequence(datasets.Value("string", id="sequence"), id="ref"), id="references"),
256
  }
257
  ),
258
- # Multi References: True, Tokenize_Sentences = True
259
  datasets.Features(
260
  {
261
  # predictions: List[str] - List of predictions
@@ -319,6 +347,12 @@ class SemF1(evaluate.Metric):
319
  :return: List of Scores dataclass with precision, recall, and F1 scores.
320
  """
321
322
  # Validate inputs corresponding to flags
323
  _validate_input_format(tokenize_sentences, multi_references, predictions, references)
324
 
@@ -363,10 +397,11 @@ class SemF1(evaluate.Metric):
363
  # Precision: Concatenate all the sentences in all the references
364
  concat_refs = np.concatenate(refs, axis=0)
365
  precision, _ = _compute_cosine_similarity(preds, concat_refs)
 
366
 
367
  # Recall: Compute individually for each reference
368
  recall_scores = [_compute_cosine_similarity(r_embeds, preds) for r_embeds in refs]
369
- recall_scores = [r_scores for (r_scores, _) in recall_scores]
370
 
371
  results.append(Scores(precision, recall_scores))
372
 
 
14
  # TODO: Add test cases, Remove tokenize_sentences flag since it can be determined from the input itself.
15
  """Sem-F1 metric"""
16
 
 
17
  from typing import List, Optional, Tuple
18
 
19
  import datasets
 
55
  """
56
 
57
  _KWARGS_DESCRIPTION = """
58
+ Sem-F1 compares the system-generated summaries (predictions) with ground truth reference summaries (references)
59
+ using precision, recall, and F1 score based on sentence embeddings.
60
 
61
  Args:
62
+ predictions (list): List of predictions. Format varies based on `tokenize_sentences` and `multi_references` flags.
63
+ references (list): List of references. Format varies based on `tokenize_sentences` and `multi_references` flags.
64
+ model_type (str): Model to use for encoding sentences. Options: ['pv1', 'stsb', 'use']
65
+ pv1 - paraphrase-distilroberta-base-v1 (Default)
66
+ stsb - stsb-roberta-large
67
+ use - Universal Sentence Encoder
68
+ tokenize_sentences (bool): Flag to indicate whether to tokenize the sentences in the input documents. Default: True.
69
+ multi_references (bool): Flag to indicate whether multiple references are provided. Default is False.
70
+ gpu (Union[bool, str, int, List[Union[str, int]]]): Whether to use GPU or CPU for computation.
71
+ bool -
 
72
  False - CPU (Default)
73
+ True - GPU (device 0) if a GPU is available, else CPU
74
+ int -
75
+ n - GPU, device index n
76
+ str -
77
+ 'cuda', 'gpu', 'cpu'
78
+ List[Union[str, int]] - Multiple GPUs/CPUs, i.e. use multiple processes when computing embeddings
79
+ batch_size (int): Batch size for encoding. Default is 32.
80
+ verbose (bool): Flag to indicate verbose output. Default is False.
81
+
82
  Returns:
83
+ List of Scores dataclass with attributes as follows -
84
+ precision: float - precision score
85
+ recall: List[float] - List of recall scores corresponding to single/multiple references
86
+ f1: float - F1 score (between precision and average recall)
87
+
88
+ Examples of input formats:
89
+
90
+ Case 1: multi_references = False, tokenize_sentences = False
91
+ predictions: List[List[str]] - List of predictions where each prediction is a list of sentences.
92
  references: List[List[str]] - List of references where each reference is a list of sentences.
93
+ Example:
94
+ predictions = [["This is a prediction sentence 1.", "This is a prediction sentence 2."]]
95
+ references = [["This is a reference sentence 1.", "This is a reference sentence 2."]]
96
+
97
+ Case 2: multi_references = False, tokenize_sentences = True
98
+ predictions: List[str] - List of predictions where each prediction is a document.
99
+ references: List[str] - List of references where each reference is a document.
100
+ Example:
101
+ predictions = ["This is a prediction sentence 1. This is a prediction sentence 2."]
102
+ references = ["This is a reference sentence 1. This is a reference sentence 2."]
103
+
104
+ Case 3: multi_references = True, tokenize_sentences = False
105
+ predictions: List[List[str]] - List of predictions where each prediction is a list of sentences.
106
+ references: List[List[List[str]]] - List of references where each example has multi-references (List[r1, r2, ...])
107
+ and each ri is a List of sentences.
108
+ Example:
109
+ predictions = [["Prediction sentence 1.", "Prediction sentence 2."]]
110
+ references = [
111
+ [
112
+ ["Reference sentence 1.", "Reference sentence 2."], # Reference 1
113
+ ["Alternative reference 1.", "Alternative reference 2."], # Reference 2
114
+ ]
115
+ ]
116
+
117
+ Case 4: multi_references = True, tokenize_sentences = True
118
+ predictions: List[str] - List of predictions where each prediction is a document.
119
+ references: List[List[str]] - List of references where each example has multi-references (List[r1, r2, ...]) where
120
+ each ri is a document.
121
+ Example:
122
+ predictions = ["Prediction sentence 1. Prediction sentence 2."]
123
+ references = [
124
+ [
125
+ "Reference sentence 1. Reference sentence 2.", # Reference 1
126
+ "Alternative reference 1. Alternative reference 2.", # Reference 2
127
+ ]
128
+ ]
129
+
130
  Examples:
131
 
132
  >>> import evaluate
133
  >>> predictions = [
134
+ ["I go to School. You are stupid."],
135
  ["I love adventure sports."],
136
  ]
137
  >>> references = [
138
+ ["I go to School. You are stupid."],
139
+ ["I love outdoor sports."],
140
  ]
141
  >>> metric = evaluate.load("semf1")
142
  >>> results = metric.compute(predictions=predictions, references=references)
143
+ >>> for score in results:
144
+ ...     print(f"Precision: {score.precision}, Recall: {score.recall}, F1: {score.f1}")
145
  """
146
 
147
 
 
217
  - `PREDICTION_TYPE` and `REFERENCE_TYPE` are defined at the top of the file
218
  """
219
 
220
+ if len(predictions) != len(references):
221
+     raise ValueError("Predictions and references must have the same length.")
222
+
223
+ def is_list_of_strings_at_depth(lst_obj, depth: int):
224
+     return is_nested_list_of_type(lst_obj, element_type=str, depth=depth)
225
+
226
  if tokenize_sentences and multi_references:
227
  condition = is_list_of_strings_at_depth(predictions, 1) and is_list_of_strings_at_depth(references, 2)
228
  elif not tokenize_sentences and multi_references:
 
253
  inputs_description=_KWARGS_DESCRIPTION,
254
  # This defines the format of each prediction and reference
255
  features=[
256
+ # F0: Multi References: False, Tokenize_Sentences = False
257
  datasets.Features(
258
  {
259
  # predictions: List[List[str]] - List of predictions where prediction is a list of sentences
 
262
  "references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"),
263
  }
264
  ),
265
+ # F1: Multi References: False, Tokenize_Sentences = True
266
  datasets.Features(
267
  {
268
  # predictions: List[str] - List of predictions
 
271
  "references": datasets.Value("string", id="sequence"),
272
  }
273
  ),
274
+ # F2: Multi References: True, Tokenize_Sentences = False
275
  datasets.Features(
276
  {
277
  # predictions: List[List[str]] - List of predictions where prediction is a list of sentences
 
283
  datasets.Sequence(datasets.Value("string", id="sequence"), id="ref"), id="references"),
284
  }
285
  ),
286
+ # F3: Multi References: True, Tokenize_Sentences = True
287
  datasets.Features(
288
  {
289
  # predictions: List[str] - List of predictions
 
347
  :return: List of Scores dataclass with precision, recall, and F1 scores.
348
  """
349
 
350
+ # Note: I have to specifically handle this case because the library considers the feature corresponding to
351
+ # this case (F2) as the feature for the other case (F0) i.e. it can't make any distinction between
352
+ # List[str] and List[List[str]]
353
+ if not tokenize_sentences and multi_references:
354
+ references = [[eval(ref) for ref in mul_ref_ex] for mul_ref_ex in references]
355
+
356
  # Validate inputs corresponding to flags
357
  _validate_input_format(tokenize_sentences, multi_references, predictions, references)
358
 
 
397
  # Precision: Concatenate all the sentences in all the references
398
  concat_refs = np.concatenate(refs, axis=0)
399
  precision, _ = _compute_cosine_similarity(preds, concat_refs)
400
+ precision = np.clip(precision, a_min=0.0, a_max=1.0).item()
401
 
402
  # Recall: Compute individually for each reference
403
  recall_scores = [_compute_cosine_similarity(r_embeds, preds) for r_embeds in refs]
404
+ recall_scores = [np.clip(r_scores, 0.0, 1.0).item() for (r_scores, _) in recall_scores]
405
 
406
  results.append(Scores(precision, recall_scores))
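For context, `_compute_cosine_similarity` itself is not shown in this diff, but the new tests below check it against the following reference computation, so a sketch of the underlying formula, using scikit-learn as the tests do, looks roughly like this (the function name here is hypothetical):

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def approx_cosine_precision_recall(pred_embeds: np.ndarray, ref_embeds: np.ndarray):
    # Pairwise cosine similarities between prediction and reference sentence embeddings.
    scores = cosine_similarity(pred_embeds, ref_embeds)
    # Precision-like score: each prediction sentence matched to its most similar reference sentence.
    precision = np.mean(np.max(scores, axis=-1)).item()
    # Recall-like score: each reference sentence matched to its most similar prediction sentence.
    recall = np.mean(np.max(scores, axis=0)).item()
    return precision, recall
```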
407
 
tests.py CHANGED
@@ -3,9 +3,12 @@ import unittest
3
 
4
  import numpy as np
5
  import torch
 
6
  from sentence_transformers import SentenceTransformer
 
7
 
8
  from encoder_models import SBertEncoder, get_encoder
 
9
  from utils import get_gpu, slice_embeddings, is_nested_list_of_type, flatten_list, compute_f1, Scores
10
 
11
 
@@ -178,5 +181,321 @@ class TestGetEncoder(unittest.TestCase):
178
  # self.assertEqual(encoder.verbose, verbose)
179
 
180
181
  if __name__ == '__main__':
182
- unittest.main()
 
 
3
 
4
  import numpy as np
5
  import torch
6
+ from numpy.testing import assert_almost_equal
7
  from sentence_transformers import SentenceTransformer
8
+ from sklearn.metrics.pairwise import cosine_similarity
9
 
10
  from encoder_models import SBertEncoder, get_encoder
11
+ from semf1 import SemF1, _compute_cosine_similarity, _validate_input_format
12
  from utils import get_gpu, slice_embeddings, is_nested_list_of_type, flatten_list, compute_f1, Scores
13
 
14
 
 
181
  # self.assertEqual(encoder.verbose, verbose)
182
 
183
 
184
+ class TestSemF1(unittest.TestCase):
185
+ def setUp(self):
186
+ self.semf1_metric = SemF1() # semf1_metric
187
+
188
+ # Example cases, #Samples = 1
189
+ self.untokenized_single_reference_predictions = [
190
+ "This is a prediction sentence 1. This is a prediction sentence 2."]
191
+ self.untokenized_single_reference_references = [
192
+ "This is a reference sentence 1. This is a reference sentence 2."]
193
+
194
+ self.tokenized_single_reference_predictions = [
195
+ ["This is a prediction sentence 1.", "This is a prediction sentence 2."],
196
+ ]
197
+ self.tokenized_single_reference_references = [
198
+ ["This is a reference sentence 1.", "This is a reference sentence 2."],
199
+ ]
200
+
201
+ self.untokenized_multi_reference_predictions = [
202
+ "Prediction sentence 1. Prediction sentence 2."
203
+ ]
204
+ self.untokenized_multi_reference_references = [
205
+ ["Reference sentence 1. Reference sentence 2.", "Alternative reference 1. Alternative reference 2."],
206
+ ]
207
+
208
+ self.tokenized_multi_reference_predictions = [
209
+ ["Prediction sentence 1.", "Prediction sentence 2."],
210
+ ]
211
+ self.tokenized_multi_reference_references = [
212
+ [
213
+ ["Reference sentence 1.", "Reference sentence 2."],
214
+ ["Alternative reference 1.", "Alternative reference 2."]
215
+ ],
216
+ ]
217
+
218
+ def test_untokenized_single_reference(self):
219
+ scores = self.semf1_metric.compute(
220
+ predictions=self.untokenized_single_reference_predictions,
221
+ references=self.untokenized_single_reference_references,
222
+ tokenize_sentences=True,
223
+ multi_references=False,
224
+ gpu=False,
225
+ batch_size=32,
226
+ verbose=False
227
+ )
228
+ self.assertIsInstance(scores, list)
229
+ self.assertEqual(len(scores), len(self.untokenized_single_reference_predictions))
230
+
231
+ def test_tokenized_single_reference(self):
232
+ scores = self.semf1_metric.compute(
233
+ predictions=self.tokenized_single_reference_predictions,
234
+ references=self.tokenized_single_reference_references,
235
+ tokenize_sentences=False,
236
+ multi_references=False,
237
+ gpu=False,
238
+ batch_size=32,
239
+ verbose=False
240
+ )
241
+ self.assertIsInstance(scores, list)
242
+ self.assertEqual(len(scores), len(self.tokenized_single_reference_predictions))
243
+
244
+ for score in scores:
245
+ self.assertIsInstance(score, Scores)
246
+ self.assertTrue(0.0 <= score.precision <= 1.0)
247
+ self.assertTrue(all(0.0 <= recall <= 1.0 for recall in score.recall))
248
+
249
+ def test_untokenized_multi_reference(self):
250
+ scores = self.semf1_metric.compute(
251
+ predictions=self.untokenized_multi_reference_predictions,
252
+ references=self.untokenized_multi_reference_references,
253
+ tokenize_sentences=True,
254
+ multi_references=True,
255
+ gpu=False,
256
+ batch_size=32,
257
+ verbose=False
258
+ )
259
+ self.assertIsInstance(scores, list)
260
+ self.assertEqual(len(scores), len(self.untokenized_multi_reference_predictions))
261
+
262
+ def test_tokenized_multi_reference(self):
263
+ scores = self.semf1_metric.compute(
264
+ predictions=self.tokenized_multi_reference_predictions,
265
+ references=self.tokenized_multi_reference_references,
266
+ tokenize_sentences=False,
267
+ multi_references=True,
268
+ gpu=False,
269
+ batch_size=32,
270
+ verbose=False
271
+ )
272
+ self.assertIsInstance(scores, list)
273
+ self.assertEqual(len(scores), len(self.tokenized_multi_reference_predictions))
274
+
275
+ for score in scores:
276
+ self.assertIsInstance(score, Scores)
277
+ self.assertTrue(0.0 <= score.precision <= 1.0)
278
+ self.assertTrue(all(0.0 <= recall <= 1.0 for recall in score.recall))
279
+
280
+ def test_same_predictions_and_references(self):
281
+ scores = self.semf1_metric.compute(
282
+ predictions=self.tokenized_single_reference_predictions,
283
+ references=self.tokenized_single_reference_predictions,
284
+ tokenize_sentences=False,
285
+ multi_references=False,
286
+ gpu=False,
287
+ batch_size=32,
288
+ verbose=False
289
+ )
290
+
291
+ self.assertIsInstance(scores, list)
292
+ self.assertEqual(len(scores), len(self.tokenized_single_reference_predictions))
293
+
294
+ for score in scores:
295
+ self.assertIsInstance(score, Scores)
296
+ self.assertAlmostEqual(score.precision, 1.0, places=6)
297
+ assert_almost_equal(score.recall, 1, decimal=5, err_msg="Not all values are almost equal to 1")
298
+
299
+ def test_exact_output_scores(self):
300
+ predictions = [
301
+ ["I go to School.", "You are stupid."],
302
+ ["I love adventure sports."],
303
+ ]
304
+ references = [
305
+ ["I go to playground.", "You are genius.", "You need to be admired."],
306
+ ["I love adventure sports."],
307
+ ]
308
+ scores = self.semf1_metric.compute(
309
+ predictions=predictions,
310
+ references=references,
311
+ tokenize_sentences=False,
312
+ multi_references=False,
313
+ gpu=False,
314
+ batch_size=32,
315
+ verbose=False,
316
+ model_type="use",
317
+ )
318
+
319
+ self.assertIsInstance(scores, list)
320
+ self.assertEqual(len(scores), len(predictions))
321
+
322
+ score = scores[0]
323
+ self.assertIsInstance(score, Scores)
324
+ self.assertAlmostEqual(score.precision, 0.73, places=2)
325
+ self.assertAlmostEqual(score.recall[0], 0.63, places=2)
326
+
327
+
328
+ class TestCosineSimilarity(unittest.TestCase):
329
+
330
+ def setUp(self):
331
+ # Sample embeddings for testing
332
+ self.pred_embeds = np.array([
333
+ [1, 0, 0],
334
+ [0, 1, 0],
335
+ [0, 0, 1]
336
+ ])
337
+ self.ref_embeds = np.array([
338
+ [1, 0, 0],
339
+ [0, 1, 0],
340
+ [0, 0, 1]
341
+ ])
342
+
343
+ self.pred_embeds_random = np.random.rand(3, 3)
344
+ self.ref_embeds_random = np.random.rand(3, 3)
345
+
346
+ def test_cosine_similarity_perfect_match(self):
347
+ precision, recall = _compute_cosine_similarity(self.pred_embeds, self.ref_embeds)
348
+
349
+ # Expected values are 1.0 for both precision and recall since embeddings are identical
350
+ self.assertAlmostEqual(precision, 1.0, places=5)
351
+ self.assertAlmostEqual(recall, 1.0, places=5)
352
+
353
+ def _test_cosine_similarity_base(self, pred_embeds, ref_embeds):
354
+ precision, recall = _compute_cosine_similarity(pred_embeds, ref_embeds)
355
+
356
+ # Calculate expected precision and recall using sklearn's cosine similarity function
357
+ cosine_scores = cosine_similarity(pred_embeds, ref_embeds)
358
+ expected_precision = np.mean(np.max(cosine_scores, axis=-1)).item()
359
+ expected_recall = np.mean(np.max(cosine_scores, axis=0)).item()
360
+
361
+ self.assertAlmostEqual(precision, expected_precision, places=5)
362
+ self.assertAlmostEqual(recall, expected_recall, places=5)
363
+
364
+ def test_cosine_similarity_random(self):
365
+ self._test_cosine_similarity_base(self.pred_embeds_random, self.ref_embeds_random)
366
+
367
+ def test_cosine_similarity_different_shapes(self):
368
+ pred_embeds_diff = np.random.rand(5, 3)
369
+ ref_embeds_diff = np.random.rand(3, 3)
370
+ self._test_cosine_similarity_base(pred_embeds_diff, ref_embeds_diff)
371
+
372
+
373
+ class TestValidateInputFormat(unittest.TestCase):
374
+ def setUp(self):
375
+ # Sample predictions and references for different scenarios where number of samples = 1
376
+ # Note on naming convention: 'untokenized_*' fixtures hold raw documents (used with tokenize_sentences=True), while 'tokenized_*' fixtures hold pre-split sentences (tokenize_sentences=False)
377
+
378
+ # When tokenize_sentences = True (untokenized input) and multi_references = False
379
+ self.untokenized_single_reference_predictions = [
380
+ "This is a prediction sentence 1. This is a prediction sentence 2."
381
+ ]
382
+ self.untokenized_single_reference_references = [
383
+ "This is a reference sentence 1. This is a reference sentence 2."
384
+ ]
385
+
386
+ # When tokenize_sentences = False (tokenized input) and multi_references = False
387
+ self.tokenized_single_reference_predictions = [
388
+ ["This is a prediction sentence 1.", "This is a prediction sentence 2."]
389
+ ]
390
+ self.tokenized_single_reference_references = [
391
+ ["This is a reference sentence 1.", "This is a reference sentence 2."]
392
+ ]
393
+
394
+ # When tokenize_sentences = True (untokenized input) and multi_references = True
395
+ self.untokenized_multi_reference_predictions = [
396
+ "This is a prediction sentence 1. This is a prediction sentence 2."
397
+ ]
398
+ self.untokenized_multi_reference_references = [
399
+ [
400
+ "This is a reference sentence 1. This is a reference sentence 2.",
401
+ "Another reference sentence."
402
+ ]
403
+ ]
404
+
405
+ # When tokenize_sentences = False (tokenized input) and multi_references = True
406
+ self.tokenized_multi_reference_predictions = [
407
+ ["This is a prediction sentence 1.", "This is a prediction sentence 2."]
408
+ ]
409
+ self.tokenized_multi_reference_references = [
410
+ [
411
+ ["This is a reference sentence 1.", "This is a reference sentence 2."],
412
+ ["Another reference sentence."]
413
+ ]
414
+ ]
415
+
416
+ def test_tokenized_sentences_true_multi_references_true(self):
417
+ # Invalid format should raise an error
418
+ with self.assertRaises(ValueError):
419
+ _validate_input_format(
420
+ True,
421
+ True,
422
+ self.tokenized_single_reference_predictions,
423
+ self.tokenized_single_reference_references,
424
+ )
425
+
426
+ # Valid format should pass without error
427
+ _validate_input_format(
428
+ True,
429
+ True,
430
+ self.untokenized_multi_reference_predictions,
431
+ self.untokenized_multi_reference_references,
432
+ )
433
+
434
+ def test_tokenized_sentences_false_multi_references_true(self):
435
+ # Invalid format should raise an error
436
+ with self.assertRaises(ValueError):
437
+ _validate_input_format(
438
+ False,
439
+ True,
440
+ self.untokenized_single_reference_predictions,
441
+ self.untokenized_multi_reference_references,
442
+ )
443
+
444
+ # Valid format should pass without error
445
+ _validate_input_format(
446
+ False,
447
+ True,
448
+ self.tokenized_multi_reference_predictions,
449
+ self.tokenized_multi_reference_references,
450
+ )
451
+
452
+ def test_tokenized_sentences_true_multi_references_false(self):
453
+ # Invalid format should raise an error
454
+ with self.assertRaises(ValueError):
455
+ _validate_input_format(
456
+ True,
457
+ False,
458
+ self.tokenized_single_reference_predictions,
459
+ self.tokenized_single_reference_references,
460
+ )
461
+
462
+ # Valid format should pass without error
463
+ _validate_input_format(
464
+ True,
465
+ False,
466
+ self.untokenized_single_reference_predictions,
467
+ self.untokenized_single_reference_references,
468
+ )
469
+
470
+ def test_tokenized_sentences_false_multi_references_false(self):
471
+ # Invalid format should raise an error
472
+ with self.assertRaises(ValueError):
473
+ _validate_input_format(
474
+ False,
475
+ False,
476
+ self.untokenized_single_reference_predictions,
477
+ self.untokenized_single_reference_references,
478
+ )
479
+
480
+ # Valid format should pass without error
481
+ _validate_input_format(
482
+ False,
483
+ False,
484
+ self.tokenized_single_reference_predictions,
485
+ self.tokenized_single_reference_references,
486
+ )
487
+
488
+ def test_mismatched_lengths(self):
489
+ # Length mismatch should raise an error
490
+ with self.assertRaises(ValueError):
491
+ _validate_input_format(
492
+ True,
493
+ True,
494
+ self.untokenized_single_reference_predictions,
495
+ [self.untokenized_single_reference_predictions[0], self.untokenized_single_reference_predictions[0]],
496
+ )
497
+
498
+
499
  if __name__ == '__main__':
500
+ unittest.main(verbosity=2)
501
+ # unittest.main()