nbansal committed on
Commit
e0e4e28
1 Parent(s): 17b14df

Handled the edge cases and added better error message

Browse files
Files changed (3) hide show
  1. semncg.py +21 -19
  2. tests.py +59 -11
  3. utils.py +55 -34
semncg.py CHANGED
@@ -308,29 +308,31 @@ def _validate_input_format(
308
  >>> _validate_input_format(tokenize_sentences, predictions, references, documents)
309
  """
310
  if not (len(predictions) == len(references) == len(documents)):
311
- raise ValueError("Predictions, References and Documents must have the same length.")
 
 
 
312
 
313
  if len(predictions) == 0:
314
  raise ValueError("Can't have empty inputs")
315
 
316
- def is_list_of_strings_at_depth(lst_obj, depth: int):
317
- return is_nested_list_of_type(lst_obj, element_type=str, depth=depth)
318
-
319
- if tokenize_sentences:
320
- condition = (
321
- is_list_of_strings_at_depth(predictions, 1) and
322
- is_list_of_strings_at_depth(references, 1) and
323
- is_list_of_strings_at_depth(documents, 1)
324
- )
325
- else:
326
- condition = (
327
- is_list_of_strings_at_depth(predictions, 2) and
328
- is_list_of_strings_at_depth(references, 2) and
329
- is_list_of_strings_at_depth(documents, 2)
330
- )
331
-
332
- if not condition:
333
- raise ValueError("Predictions, References and Documents are not valid input format. Refer to documentation.")
334
 
335
 
336
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
 
308
  >>> _validate_input_format(tokenize_sentences, predictions, references, documents)
309
  """
310
  if not (len(predictions) == len(references) == len(documents)):
311
+ raise ValueError(
312
+ f"Predictions, References and Documents must have the same length. "
313
+ f"Got {len(predictions)} predictions, {len(references)} references and {len(documents)} documents."
314
+ )
315
 
316
  if len(predictions) == 0:
317
  raise ValueError("Can't have empty inputs")
318
 
319
+ def check_format(lst_obj, expected_depth: int, name: str):
320
+ is_valid, error_message = is_nested_list_of_type(lst_obj, element_type=str, depth=expected_depth)
321
+ if not is_valid:
322
+ raise ValueError(f"{name} are not in the expected format.\n"
323
+ f"Error: {error_message}.")
324
+
325
+ try:
326
+ if tokenize_sentences:
327
+ check_format(predictions, expected_depth=1, name="predictions")
328
+ check_format(references, expected_depth=1, name="references")
329
+ check_format(documents, expected_depth=1, name="documents")
330
+ else:
331
+ check_format(predictions, expected_depth=2, name="predictions")
332
+ check_format(references, expected_depth=2, name="references")
333
+ check_format(documents, expected_depth=2, name="documents")
334
+ except ValueError as ve:
335
+ raise ValueError(f"Input validation error: {ve}")
 
336
 
337
 
338
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
tests.py CHANGED
@@ -139,29 +139,35 @@ class TestUtils(unittest.TestCase):
139
 
140
  def test_is_nested_list_of_type(self):
141
  # Test case: Depth 0, single element matching element_type
142
- self.assertTrue(is_nested_list_of_type("test", str, 0))
143
 
144
  # Test case: Depth 0, single element not matching element_type
145
- self.assertFalse(is_nested_list_of_type("test", int, 0))
 
146
 
147
  # Test case: Depth 1, list of elements matching element_type
148
- self.assertTrue(is_nested_list_of_type(["apple", "banana"], str, 1))
149
 
150
  # Test case: Depth 1, list of elements not matching element_type
151
- self.assertFalse(is_nested_list_of_type([1, 2, 3], str, 1))
 
152
 
153
  # Test case: Depth 0 (Wrong), list of elements matching element_type
154
- self.assertFalse(is_nested_list_of_type([1, 2, 3], str, 0))
 
155
 
156
  # Depth 2
157
- self.assertTrue(is_nested_list_of_type([[1, 2], [3, 4]], int, 2))
158
- self.assertTrue(is_nested_list_of_type([['1', '2'], ['3', '4']], str, 2))
159
- self.assertFalse(is_nested_list_of_type([[1, 2], ["a", "b"]], int, 2))
 
160
 
161
  # Depth 3
162
- self.assertFalse(is_nested_list_of_type([[[1], [2]], [[3], [4]]], list, 3))
163
- self.assertTrue(is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3))
 
164
 
 
165
  with self.assertRaises(ValueError):
166
  is_nested_list_of_type([1, 2], int, -1)
167
 
@@ -358,7 +364,7 @@ class TestValidateInputFormat(unittest.TestCase):
358
  _validate_input_format(tokenize_sentences, predictions, references, documents_invalid)
359
 
360
 
361
- class TestSemnCG(unittest.TestCase):
362
  def setUp(self):
363
  self.model_name = "stsb-distilbert-base"
364
  self.metric = SemNCG(self.model_name)
@@ -424,6 +430,48 @@ class TestSemnCG(unittest.TestCase):
424
  with self.assertRaises(ValueError):
425
  self.metric.compute(predictions=predictions, references=references, documents=documents)
426
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
427
 
428
  if __name__ == '__main__':
429
  unittest.main(verbosity=2)
 
139
 
140
  def test_is_nested_list_of_type(self):
141
  # Test case: Depth 0, single element matching element_type
142
+ self.assertEqual(is_nested_list_of_type("test", str, 0), (True, ""))
143
 
144
  # Test case: Depth 0, single element not matching element_type
145
+ is_valid, err_msg = is_nested_list_of_type("test", int, 0)
146
+ self.assertEqual(is_valid, False)
147
 
148
  # Test case: Depth 1, list of elements matching element_type
149
+ self.assertEqual(is_nested_list_of_type(["apple", "banana"], str, 1), (True, ""))
150
 
151
  # Test case: Depth 1, list of elements not matching element_type
152
+ is_valid, err_msg = is_nested_list_of_type([1, 2, 3], str, 1)
153
+ self.assertEqual(is_valid, False)
154
 
155
  # Test case: Depth 0 (Wrong), list of elements matching element_type
156
+ is_valid, err_msg = is_nested_list_of_type([1, 2, 3], str, 0)
157
+ self.assertEqual(is_valid, False)
158
 
159
  # Depth 2
160
+ self.assertEqual(is_nested_list_of_type([[1, 2], [3, 4]], int, 2), (True, ""))
161
+ self.assertEqual(is_nested_list_of_type([['1', '2'], ['3', '4']], str, 2), (True, ""))
162
+ is_valid, err_msg = is_nested_list_of_type([[1, 2], ["a", "b"]], int, 2)
163
+ self.assertEqual(is_valid, False)
164
 
165
  # Depth 3
166
+ is_valid, err_msg = is_nested_list_of_type([[[1], [2]], [[3], [4]]], list, 3)
167
+ self.assertEqual(is_valid, False)
168
+ self.assertEqual(is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3), (True, ""))
169
 
170
+ # Test case: Depth is negative, expecting ValueError
171
  with self.assertRaises(ValueError):
172
  is_nested_list_of_type([1, 2], int, -1)
173
 
 
364
  _validate_input_format(tokenize_sentences, predictions, references, documents_invalid)
365
 
366
 
367
+ class TestSemNCG(unittest.TestCase):
368
  def setUp(self):
369
  self.model_name = "stsb-distilbert-base"
370
  self.metric = SemNCG(self.model_name)
 
430
  with self.assertRaises(ValueError):
431
  self.metric.compute(predictions=predictions, references=references, documents=documents)
432
 
433
+ def test_bad_inputs(self):
434
+ def _call_metric(preds, refs, docs, tok):
435
+ with self.assertRaises(Exception) as ctx:
436
+ _ = self.metric.compute(
437
+ predictions=preds,
438
+ references=refs,
439
+ documents=docs,
440
+ tokenize_sentences=tok,
441
+ pre_compute_embeddings=True,
442
+ )
443
+ print(f"Raised Exception with message: {ctx.exception}")
444
+ return ""
445
+
446
+ # None Inputs
447
+ # Case I
448
+ tokenize_sentences = True
449
+ predictions = [None]
450
+ references = ["A cat was sitting on a mat."]
451
+ documents = ["There was a cat on a mat."]
452
+ print(f"Case I\n{_call_metric(predictions, references, documents, tokenize_sentences)}\n")
453
+
454
+ # Case II
455
+ tokenize_sentences = False
456
+ predictions = [["A cat was sitting on a mat.", None]]
457
+ references = [["A cat was sitting on a mat.", "A cat was sitting on a mat."]]
458
+ documents = [["There was a cat on a mat.", "There was a cat on a mat."]]
459
+ print(f"Case II\n{_call_metric(predictions, references, documents, tokenize_sentences)}\n")
460
+
461
+ # Empty Input
462
+ tokenize_sentences = True
463
+ predictions = []
464
+ references = ["A cat was sitting on a mat."]
465
+ documents = ["There was a cat on a mat."]
466
+ print(f"Case: Empty Input\n{_call_metric(predictions, references, documents, tokenize_sentences)}\n")
467
+
468
+ # Empty String Input
469
+ tokenize_sentences = True
470
+ predictions = [""]
471
+ references = ["A cat was sitting on a mat."]
472
+ documents = ["There was a cat on a mat."]
473
+ print(f"Case: Empty String Input\n{_call_metric(predictions, references, documents, tokenize_sentences)}\n")
474
+
475
 
476
  if __name__ == '__main__':
477
  unittest.main(verbosity=2)
utils.py CHANGED
@@ -1,5 +1,5 @@
1
  import string
2
- from typing import List, Union
3
 
4
  import nltk
5
  import torch
@@ -167,45 +167,66 @@ def flatten_list(nested_list: list) -> list:
167
  return flat_list
168
 
169
 
170
- def is_nested_list_of_type(lst_obj, element_type, depth: int) -> bool:
171
  """
172
- Check if the given object is a nested list of a specific type up to a specified depth.
173
 
174
- Args:
175
- - lst_obj: The object to check, expected to be a list or a single element.
176
- - element_type: The type that each element in the nested list should match.
177
- - depth (int): The depth of nesting to check. Must be non-negative.
178
-
179
- Returns:
180
- - bool: True if lst_obj is a nested list of the specified type up to the given depth, False otherwise.
181
 
182
- Raises:
183
- - ValueError: If depth is negative.
 
 
184
 
185
- Example:
186
- ```python
187
- # Test cases
188
- is_nested_list_of_type("test", str, 0) # Returns True
189
- is_nested_list_of_type([1, 2, 3], str, 0) # Returns False
190
- is_nested_list_of_type(["apple", "banana"], str, 1) # Returns True
191
- is_nested_list_of_type([[1, 2], [3, 4]], int, 2) # Returns True
192
- is_nested_list_of_type([[1, 2], ["a", "b"]], int, 2) # Returns False
193
- is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3) # Returns True
194
- ```
195
 
196
- Explanation:
197
- - The function checks if `lst_obj` is a nested list of elements of type `element_type` up to `depth` levels deep.
198
- - If `depth` is 0, it checks if `lst_obj` itself is of type `element_type`.
199
- - If `depth` is greater than 0, it recursively checks each level of nesting to ensure all elements match `element_type`.
200
- - Raises a `ValueError` if `depth` is negative, as depth must be a non-negative integer.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  """
202
- if depth == 0:
203
- return isinstance(lst_obj, element_type)
204
- elif depth > 0:
205
- return isinstance(lst_obj, list) and all(
206
- is_nested_list_of_type(item, element_type, depth - 1) for item in lst_obj)
207
- else:
208
- raise ValueError("Depth can't be negative")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
 
211
  def slice_embeddings(embeddings: NDArray, num_sentences: NumSentencesType) -> EmbeddingSlicesType:
 
1
  import string
2
+ from typing import List, Union, Tuple
3
 
4
  import nltk
5
  import torch
 
167
  return flat_list
168
 
169
 
170
def is_nested_list_of_type(lst_obj, element_type, depth: int) -> Tuple[bool, str]:
    """
    Check whether an object is a list nested exactly `depth` levels deep whose leaves match `element_type`.

    Args:
    - lst_obj: The object to check, expected to be a (nested) list or, for depth 0, a single element.
    - element_type: The type each leaf element must match. Anything accepted by `isinstance` works,
      including a tuple of types.
    - depth (int): The number of list levels expected above the leaf elements. Must be non-negative.

    Returns:
    - Tuple[bool, str]: A tuple containing:
        - A boolean that is True when `lst_obj` has the expected nesting and leaf type.
        - An error message describing the first mismatch found, or an empty string when the check passes.

    Raises:
    - ValueError: If depth is negative.

    Example:
    ```python
    is_nested_list_of_type("test", str, 0)  # Returns (True, "")
    is_nested_list_of_type([1, 2, 3], str, 0)  # Returns (False, "Element is of type list, expected type str.")
    is_nested_list_of_type(["apple", "banana"], str, 1)  # Returns (True, "")
    is_nested_list_of_type("apple", str, 1)  # Returns (False, "Object is not a list but str.")
    is_nested_list_of_type([[1, 2], [3, 4]], int, 2)  # Returns (True, "")
    is_nested_list_of_type([[1, 2], ["a", "b"]], int, 2)
    # Returns (False, "Element at index 1 has incorrect type.\\n"
    #                 "Given Element at index 1: ['a', 'b']\\n"
    #                 "Element is of type str, expected type int.")
    is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3)  # Returns (True, "")
    ```

    Explanation:
    - If `depth` is 0, `lst_obj` itself must be an instance of `element_type`.
    - If `depth` is greater than 0, `lst_obj` must be a list and every item is checked recursively
      with `depth - 1`. An empty list passes at any positive depth.
    - The check stops at the first failure. Only the outermost level adds the index and the offending
      item to the message; deeper levels pass the underlying reason through unchanged.
    """
    # Fail fast on a bad depth instead of discovering it inside the recursion.
    if depth < 0:
        raise ValueError("Depth can't be negative")

    def _type_name(tp) -> str:
        # isinstance accepts a tuple of types, which has no __name__ of its own.
        if isinstance(tp, tuple):
            return " or ".join(_type_name(t) for t in tp)
        return getattr(tp, "__name__", str(tp))

    def _check(obj, remaining: int) -> Tuple[bool, str]:
        if remaining == 0:
            if isinstance(obj, element_type):
                return True, ""
            return False, (
                f"Element is of type {type(obj).__name__}, "
                f"expected type {_type_name(element_type)}."
            )

        if not isinstance(obj, list):
            return False, f"Object is not a list but {type(obj).__name__}."

        for index, item in enumerate(obj):
            is_valid, err = _check(item, remaining - 1)
            if not is_valid:
                # Only the top-level call decorates the message, so the caller sees which
                # top-level entry is broken without one prefix per nesting level.
                if remaining == depth:
                    err = (
                        f"Element at index {index} has incorrect type.\n"
                        f"Given Element at index {index}: {item}\n{err}"
                    )
                return False, err
        return True, ""

    return _check(lst_obj, depth)
230
 
231
 
232
  def slice_embeddings(embeddings: NDArray, num_sentences: NumSentencesType) -> EmbeddingSlicesType: