nbansal committed
Commit 57111be
1 Parent(s): 0377c9d

Handled the None and empty string cases

Files changed (3)
  1. semf1.py +26 -17
  2. tests.py +104 -12
  3. utils.py +78 -34
semf1.py CHANGED
@@ -27,7 +27,7 @@ from sklearn.metrics.pairwise import cosine_similarity
 
 from .encoder_models import get_encoder
 from .type_aliases import DEVICE_TYPE, PREDICTION_TYPE, REFERENCE_TYPE
-from .utils import is_nested_list_of_type, Scores, slice_embeddings, flatten_list, get_gpu
+from .utils import is_nested_list_of_type, Scores, slice_embeddings, flatten_list, get_gpu, sent_tokenize
 
 _CITATION = """\
 @inproceedings{bansal-etal-2022-sem,
@@ -223,22 +223,33 @@ def _validate_input_format(
     """
 
     if len(predictions) != len(references):
-        raise ValueError("Predictions and references must have the same length.")
+        raise ValueError(f"Predictions and references must have the same length. "
+                         f"Got {len(predictions)} predictions and {len(references)} references.")
 
     def is_list_of_strings_at_depth(lst_obj, depth: int):
         return is_nested_list_of_type(lst_obj, element_type=str, depth=depth)
 
-    if tokenize_sentences and multi_references:
-        condition = is_list_of_strings_at_depth(predictions, 1) and is_list_of_strings_at_depth(references, 2)
-    elif not tokenize_sentences and multi_references:
-        condition = is_list_of_strings_at_depth(predictions, 2) and is_list_of_strings_at_depth(references, 3)
-    elif tokenize_sentences and not multi_references:
-        condition = is_list_of_strings_at_depth(predictions, 1) and is_list_of_strings_at_depth(references, 1)
-    else:
-        condition = is_list_of_strings_at_depth(predictions, 2) and is_list_of_strings_at_depth(references, 2)
-
-    if not condition:
-        raise ValueError("Predictions are references are not valid input format. Refer to documentation.")
+    def check_format(lst_obj, expected_depth: int, name: str):
+        is_valid, error_message = is_list_of_strings_at_depth(lst_obj, expected_depth)
+        if not is_valid:
+            raise ValueError(f"{name} are not in the expected format.\n"
+                             f"Error: {error_message}.")
+
+    try:
+        if tokenize_sentences and multi_references:
+            check_format(predictions, 1, "Predictions")
+            check_format(references, 2, "References")
+        elif not tokenize_sentences and multi_references:
+            check_format(predictions, 2, "Predictions")
+            check_format(references, 3, "References")
+        elif tokenize_sentences and not multi_references:
+            check_format(predictions, 1, "Predictions")
+            check_format(references, 1, "References")
+        else:
+            check_format(predictions, 2, "Predictions")
+            check_format(references, 2, "References")
+    except ValueError as ve:
+        raise ValueError(f"Input validation error: {ve}")
 
 
 @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
@@ -317,8 +328,6 @@ class SemF1(evaluate.Metric):
         """Optional: download external resources useful to compute the scores"""
         import nltk
         nltk.download("punkt", quiet=True)
-        # if not nltk.data.find("tokenizers/punkt"):  # TODO: check why it is not working
-        #     pass
 
     def _compute(
             self,
@@ -377,8 +386,8 @@ class SemF1(evaluate.Metric):
 
         # Tokenize sentences if required
         if tokenize_sentences:
-            predictions = [nltk.tokenize.sent_tokenize(pred) for pred in predictions]
-            references = [[nltk.tokenize.sent_tokenize(ref) for ref in refs] for refs in references]
+            predictions = [sent_tokenize(pred) for pred in predictions]
+            references = [[sent_tokenize(ref) for ref in refs] for refs in references]
 
         # Flatten the data for batch processing
         all_sentences = flatten_list(predictions) + flatten_list(references)

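For orientation, the depths checked in `_validate_input_format` above correspond to concrete input shapes per flag combination. The snippet below is a minimal sketch of what callers now see; the `evaluate.load` path and the example texts are assumptions for illustration, not part of this commit.

```python
import evaluate

# Hypothetical load path, for illustration only.
semf1 = evaluate.load("nbansal/semf1")

# Expected shapes checked by the validator:
# tokenize_sentences=True,  multi_references=True  -> predictions: List[str],       references: List[List[str]]
# tokenize_sentences=False, multi_references=True  -> predictions: List[List[str]], references: List[List[List[str]]]
# tokenize_sentences=True,  multi_references=False -> predictions: List[str],       references: List[str]
# tokenize_sentences=False, multi_references=False -> predictions: List[List[str]], references: List[List[str]]

# A length mismatch now surfaces the more descriptive error added above, e.g.:
# ValueError: Predictions and references must have the same length. Got 2 predictions and 1 references.
semf1.compute(
    predictions=["The cat sat.", "Dogs bark."],
    references=[["A cat was sitting."]],
    tokenize_sentences=True,
    multi_references=True,
)
```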
tests.py CHANGED
@@ -1,6 +1,5 @@
 import statistics
 import unittest
-from unittest.mock import patch, MagicMock
 
 import numpy as np
 import torch
@@ -73,29 +72,36 @@ class TestUtils(unittest.TestCase):
 
     def test_is_nested_list_of_type(self):
         # Test case: Depth 0, single element matching element_type
-        self.assertTrue(is_nested_list_of_type("test", str, 0))
+        self.assertEqual(is_nested_list_of_type("test", str, 0), (True, ""))
 
         # Test case: Depth 0, single element not matching element_type
-        self.assertFalse(is_nested_list_of_type("test", int, 0))
+        is_valid, err_msg = is_nested_list_of_type("test", int, 0)
+        self.assertEqual(is_valid, False)
 
         # Test case: Depth 1, list of elements matching element_type
-        self.assertTrue(is_nested_list_of_type(["apple", "banana"], str, 1))
+        self.assertEqual(is_nested_list_of_type(["apple", "banana"], str, 1), (True, ""))
 
         # Test case: Depth 1, list of elements not matching element_type
-        self.assertFalse(is_nested_list_of_type([1, 2, 3], str, 1))
+        is_valid, err_msg = is_nested_list_of_type([1, 2, 3], str, 1)
+        self.assertEqual(is_valid, False)
 
         # Test case: Depth 0 (Wrong), list of elements matching element_type
-        self.assertFalse(is_nested_list_of_type([1, 2, 3], str, 0))
+        is_valid, err_msg = is_nested_list_of_type([1, 2, 3], str, 0)
+        self.assertEqual(is_valid, False)
 
         # Depth 2
-        self.assertTrue(is_nested_list_of_type([[1, 2], [3, 4]], int, 2))
-        self.assertTrue(is_nested_list_of_type([['1', '2'], ['3', '4']], str, 2))
-        self.assertFalse(is_nested_list_of_type([[1, 2], ["a", "b"]], int, 2))
+        self.assertEqual(is_nested_list_of_type([[1, 2], [3, 4]], int, 2), (True, ""))
+        self.assertEqual(is_nested_list_of_type([['1', '2'], ['3', '4']], str, 2), (True, ""))
+        is_valid, err_msg = is_nested_list_of_type([[1, 2], ["a", "b"]], int, 2)
+        self.assertEqual(is_valid, False)
 
         # Depth 3
-        self.assertFalse(is_nested_list_of_type([[[1], [2]], [[3], [4]]], list, 3))
-        self.assertTrue(is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3))
+        is_valid, err_msg = is_nested_list_of_type([[[1], [2]], [[3], [4]]], list, 3)
+        self.assertEqual(is_valid, False)
+        self.assertEqual(is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3), (True, ""))
 
+        # Test case: Depth is negative, expecting ValueError
         with self.assertRaises(ValueError):
             is_nested_list_of_type([1, 2], int, -1)
 
@@ -335,6 +341,93 @@ class TestSemF1(unittest.TestCase):
         self.assertAlmostEqual(score.precision, 0.73, places=2)
         self.assertAlmostEqual(score.recall[0], 0.63, places=2)
 
+    def test_none_input(self):
+        def _call_metric(preds, refs, tok, mul_ref):
+            with self.assertRaises(ValueError) as ctx:
+                _ = self.semf1_metric.compute(
+                    predictions=preds,
+                    references=refs,
+                    tokenize_sentences=tok,
+                    multi_references=mul_ref,
+                    gpu=False,
+                    batch_size=32,
+                    verbose=False,
+                    model_type="use",
+                )
+            print(f"Raised ValueError with message: {ctx.exception}")
+            return ""
+
+        # Case 1: tokenize_sentences = True, multi_references = True
+        tokenize_sentences = True
+        multi_references = True
+        predictions = [
+            "I go to School. You are stupid.",
+            "I go to School. You are stupid.",
+        ]
+        references = [
+            ["I am", "I am"],
+            [None, "I am"],
+        ]
+        print(f"Case I\n{_call_metric(predictions, references, tokenize_sentences, multi_references)}\n")
+
+        # Case 2: tokenize_sentences = False, multi_references = True
+        tokenize_sentences = False
+        multi_references = True
+        predictions = [
+            ["I go to School.", "You are stupid."],
+            ["I go to School.", "You are stupid."],
+        ]
+        references = [
+            [["I am", "I am"], [None, "I am"]],
+            [[None, "I am"]],
+        ]
+        print(f"Case II\n{_call_metric(predictions, references, tokenize_sentences, multi_references)}\n")
+
+        # Case 3: tokenize_sentences = True, multi_references = False
+        tokenize_sentences = True
+        multi_references = False
+        predictions = [
+            None,
+            "I go to School. You are stupid.",
+        ]
+        references = [
+            "I am. I am.",
+            "I am. I am.",
+        ]
+        print(f"Case III\n{_call_metric(predictions, references, tokenize_sentences, multi_references)}\n")
+
+        # Case 4: tokenize_sentences = False, multi_references = False
+        # This is taken care of by the library itself
+        tokenize_sentences = False
+        multi_references = False
+        predictions = [
+            ["I go to School.", None],
+            ["I go to School.", "You are stupid."],
+        ]
+        references = [
+            ["I am.", "I am."],
+            ["I am.", "I am."],
+        ]
+        print(f"Case IV\n{_call_metric(predictions, references, tokenize_sentences, multi_references)}\n")
+
+    def test_empty_input(self):
+        predictions = [""]
+        references = ["I go to School. You are stupid."]
+        scores = self.semf1_metric.compute(
+            predictions=predictions,
+            references=references,
+        )
+        print(scores)
+
+        # # Test with Gibberish Cases
+        # predictions = ["lth cgezawrxretxdr", "dsfgsdfhsdfh"]
+        # references = ["dzfgzeWfnAfse", "dtjsrtzerZJSEWr"]
+        # scores = self.semf1_metric.compute(
+        #     predictions=predictions,
+        #     references=references,
+        # )
+        # print(scores)
+
 
 class TestCosineSimilarity(unittest.TestCase):
 
@@ -509,4 +602,3 @@ class TestValidateInputFormat(unittest.TestCase):
 
 if __name__ == '__main__':
     unittest.main(verbosity=2)
-
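To run just the two cases added in this commit locally, a sketch like the following works (assuming the metric's dependencies, including NLTK's punkt data and the encoder models, are available in your environment):

```python
import unittest

# Hypothetical selective run of the new tests; the full suite runs via `unittest.main` as above.
suite = unittest.TestSuite()
suite.addTest(unittest.defaultTestLoader.loadTestsFromNames(
    ["tests.TestSemF1.test_none_input", "tests.TestSemF1.test_empty_input"]
))
unittest.TextTestRunner(verbosity=2).run(suite)
```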
 
utils.py CHANGED
@@ -1,8 +1,10 @@
 import statistics
+import string
 import sys
 from dataclasses import dataclass, field
-from typing import List, Union
+from typing import List, Union, Tuple
 
+import nltk
 import torch
 from numpy.typing import NDArray
 
@@ -149,44 +151,65 @@ def slice_embeddings(embeddings: NDArray, num_sentences: NumSentencesType) -> Em
         raise TypeError(f"Incorrect Type for {num_sentences=}")
 
 
-def is_nested_list_of_type(lst_obj, element_type, depth: int) -> bool:
+def is_nested_list_of_type(lst_obj, element_type, depth: int) -> Tuple[bool, str]:
     """
-    Check if the given object is a nested list of a specific type up to a specified depth.
-
-    Args:
-    - lst_obj: The object to check, expected to be a list or a single element.
-    - element_type: The type that each element in the nested list should match.
-    - depth (int): The depth of nesting to check. Must be non-negative.
-
-    Returns:
-    - bool: True if lst_obj is a nested list of the specified type up to the given depth, False otherwise.
-
-    Raises:
-    - ValueError: If depth is negative.
-
-    Example:
-    ```python
-    # Test cases
-    is_nested_list_of_type("test", str, 0)  # Returns True
-    is_nested_list_of_type([1, 2, 3], str, 0)  # Returns False
-    is_nested_list_of_type(["apple", "banana"], str, 1)  # Returns True
-    is_nested_list_of_type([[1, 2], [3, 4]], int, 2)  # Returns True
-    is_nested_list_of_type([[1, 2], ["a", "b"]], int, 2)  # Returns False
-    is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3)  # Returns True
-    ```
-
-    Explanation:
-    - The function checks if `lst_obj` is a nested list of elements of type `element_type` up to `depth` levels deep.
-    - If `depth` is 0, it checks if `lst_obj` itself is of type `element_type`.
-    - If `depth` is greater than 0, it recursively checks each level of nesting to ensure all elements match `element_type`.
-    - Raises a `ValueError` if `depth` is negative, as depth must be a non-negative integer.
-    """
-    if depth == 0:
-        return isinstance(lst_obj, element_type)
-    elif depth > 0:
-        return isinstance(lst_obj, list) and all(is_nested_list_of_type(item, element_type, depth - 1) for item in lst_obj)
-    else:
-        raise ValueError("Depth can't be negative")
+    Check if the given object is a nested list of a specific type up to a specified depth.
+
+    Args:
+    - lst_obj: The object to check, expected to be a list or a single element.
+    - element_type: The type that each element in the nested list should match.
+    - depth (int): The depth of nesting to check. Must be non-negative.
+
+    Returns:
+    - Tuple[bool, str]: A tuple containing:
+        - A boolean indicating if lst_obj is a nested list of the specified type up to the given depth.
+        - A string containing an error message if the check fails, or an empty string if the check passes.
+
+    Raises:
+    - ValueError: If depth is negative.
+
+    Example:
+    ```python
+    # Test cases
+    is_nested_list_of_type("test", str, 0)  # Returns (True, "")
+    is_nested_list_of_type([1, 2, 3], str, 0)  # Returns (False, "Element is of type int, expected type str.")
+    is_nested_list_of_type(["apple", "banana"], str, 1)  # Returns (True, "")
+    is_nested_list_of_type([[1, 2], [3, 4]], int, 2)  # Returns (True, "")
+    is_nested_list_of_type([[1, 2], ["a", "b"]], int, 2)  # Returns (False, "Element at index 1 is of incorrect type.")
+    is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3)  # Returns (True, "")
+    ```
+
+    Explanation:
+    - The function checks if `lst_obj` is a nested list of elements of type `element_type` up to `depth` levels deep.
+    - If `depth` is 0, it checks if `lst_obj` itself is of type `element_type`.
+    - If `depth` is greater than 0, it recursively checks each level of nesting to ensure all elements match
+      `element_type`.
+    - Returns a tuple containing a boolean and an error message. The boolean is `True` if `lst_obj` matches the
+      criteria, `False` otherwise. The error message provides details if the check fails.
+    - Raises a `ValueError` if `depth` is negative, as depth must be a non-negative integer.
+    """
+    orig_depth = depth
+
+    def _is_nested_list_of_type(lst_o, e_type, d) -> Tuple[bool, str]:
+        if d == 0:
+            if isinstance(lst_o, e_type):
+                return True, ""
+            else:
+                return False, f"Element is of type {type(lst_o).__name__}, expected type {e_type.__name__}."
+        elif d > 0:
+            if isinstance(lst_o, list):
+                for i, item in enumerate(lst_o):
+                    is_valid, err = _is_nested_list_of_type(item, e_type, d - 1)
+                    if not is_valid:
+                        msg = f"Element at index {i} has incorrect type.\n{err}" if d == orig_depth else err
+                        return False, msg
+                return True, ""
+            else:
+                return False, f"Object is not a list but {type(lst_o)}."
+        else:
+            raise ValueError("Depth can't be negative")
+
+    return _is_nested_list_of_type(lst_obj, element_type, depth)
 
 
 def flatten_list(nested_list: list) -> list:
@@ -220,6 +243,27 @@ def compute_f1(p: float, r: float, eps=sys.float_info.epsilon) -> float:
     return f1
 
 
+def sent_tokenize(text: str) -> List[str]:
+    """
+    Tokenizes the input text into a list of sentences.
+
+    This function uses the NLTK library's sentence tokenizer to split the input
+    text into individual sentences. Leading and trailing whitespace is removed
+    from the input text before tokenization. If the input text is empty or consists
+    only of whitespace, a list containing an empty string is returned.
+
+    Args:
+        text (str): The input text to be tokenized into sentences.
+
+    Returns:
+        List[str]: A list of sentences tokenized from the input text.
+    """
+    text = text.strip()
+    if text == "":
+        return [""]
+    return [sent.strip() for sent in nltk.tokenize.sent_tokenize(text)]
+
+
 @dataclass
 class Scores:
     """