nbansal commited on
Commit
a249916
1 Parent(s): de5dcb7

Major refactoring and added test cases

Browse files
Files changed (6) hide show
  1. .gitignore +1 -0
  2. encoder_models.py +108 -0
  3. semf1.py +74 -69
  4. tests.py +179 -17
  5. type_aliases.py +10 -0
  6. utils.py +78 -10
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__/
encoder_models.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ from typing import List, Union
3
+
4
+ from numpy.typing import NDArray
5
+ from sentence_transformers import SentenceTransformer
6
+
7
+ from type_aliases import ENCODER_DEVICE_TYPE
8
+
9
+
10
+ class Encoder(abc.ABC):
11
+ @abc.abstractmethod
12
+ def encode(self, prediction: List[str]) -> NDArray:
13
+ """
14
+ Abstract method to encode a list of sentences into sentence embeddings.
15
+
16
+ Args:
17
+ prediction (List[str]): List of sentences to encode.
18
+
19
+ Returns:
20
+ NDArray: Array of sentence embeddings with shape (num_sentences, embedding_dim).
21
+
22
+ Raises:
23
+ NotImplementedError: If the method is not implemented in the subclass.
24
+ """
25
+ raise NotImplementedError("Method 'encode' must be implemented in subclass.")
26
+
27
+
28
+ class USE(Encoder):
29
+ def __init__(self):
30
+ pass
31
+
32
+ def encode(self, prediction: List[str]) -> NDArray:
33
+ pass
34
+
35
+
36
+ class SBertEncoder(Encoder):
37
+ def __init__(self, model_name: str, device: ENCODER_DEVICE_TYPE, batch_size: int, verbose: bool):
38
+ """
39
+ Initialize SBertEncoder instance.
40
+
41
+ Args:
42
+ model_name (str): Name or path of the Sentence Transformer model.
43
+ device (Union[str, int, List[Union[str, int]]]): Device specification for encoding
44
+ batch_size (int): Batch size for encoding.
45
+ verbose (bool): Whether to print verbose information during encoding.
46
+ """
47
+ self.model = SentenceTransformer(model_name)
48
+ self.device = device
49
+ self.batch_size = batch_size
50
+ self.verbose = verbose
51
+
52
+ def encode(self, prediction: List[str]) -> NDArray:
53
+ """
54
+ Encode a list of sentences into sentence embeddings.
55
+
56
+ Args:
57
+ prediction (List[str]): List of sentences to encode.
58
+
59
+ Returns:
60
+ NDArray: Array of sentence embeddings with shape (num_sentences, embedding_dim).
61
+ """
62
+
63
+ # SBert output is always Batch x Dim
64
+ if isinstance(self.device, list):
65
+ # Use multiprocess encoding for list of devices
66
+ pool = self.model.start_multi_process_pool(target_devices=self.device)
67
+ embeddings = self.model.encode_multi_process(prediction, pool=pool, batch_size=self.batch_size)
68
+ self.model.stop_multi_process_pool(pool)
69
+ else:
70
+ # Single device encoding
71
+ embeddings = self.model.encode(
72
+ prediction,
73
+ device=self.device,
74
+ batch_size=self.batch_size,
75
+ show_progress_bar=self.verbose,
76
+ )
77
+
78
+ return embeddings
79
+
80
+
81
+ def get_encoder(model_name: str, device: ENCODER_DEVICE_TYPE, batch_size: int, verbose: bool) -> Encoder:
82
+ """
83
+ Get the encoder instance based on the specified model name.
84
+
85
+ Args:
86
+ model_name (str): Name of the model to instantiate
87
+ Options: [pv1, stsb, use]
88
+ pv1 - paraphrase-distilroberta-base-v1 (Default)
89
+ stsb - stsb-roberta-large
90
+ use - Universal Sentence Encoder
91
+ device (Union[str, int, List[Union[str, int]]): Device specification for the encoder
92
+ (e.g., "cuda", 0 for GPU, "cpu").
93
+ batch_size (int): Batch size for encoding.
94
+ verbose (bool): Whether to print verbose information during encoder initialization.
95
+
96
+ Returns:
97
+ Encoder: Instance of the selected encoder based on the model_name.
98
+
99
+ Raises:
100
+ ValueError: If an unsupported model_name is provided.
101
+ """
102
+
103
+ # TODO: chnage this when changing the TF model
104
+ if model_name == "use":
105
+ return SBertEncoder("sentence-transformers/use-cmlm-multilingual", device, batch_size, verbose)
106
+ # return USE()
107
+ else:
108
+ return SBertEncoder(model_name, device, batch_size, verbose)
semf1.py CHANGED
@@ -14,21 +14,19 @@
14
  # TODO: Add test cases, Remove tokenize_sentences flag since it can be determined from the input itself.
15
  """Sem-F1 metric"""
16
 
17
- import abc
18
- import sys
19
- from typing import List, Optional, Tuple, Union
20
 
21
  import datasets
22
  import evaluate
23
  import nltk
24
  import numpy as np
25
  from numpy.typing import NDArray
26
- from sentence_transformers import SentenceTransformer
27
  from sklearn.metrics.pairwise import cosine_similarity
28
- import torch
29
- from tqdm import tqdm
30
 
31
- from utils import is_list_of_strings_at_depth, Scores, slice_embeddings, flatten_list
 
 
32
 
33
  _CITATION = """\
34
  @inproceedings{bansal-etal-2022-sem,
@@ -123,80 +121,80 @@ Examples:
123
  [0.77, 0.56]
124
  """
125
 
126
- _PREDICTION_TYPE = Union[List[str], List[List[str]]]
127
- _REFERENCE_TYPE = Union[List[str], List[List[str]], List[List[List[str]]]]
128
 
 
 
 
129
 
130
- class Encoder(metaclass=abc.ABCMeta):
131
- @abc.abstractmethod
132
- def encode(self, prediction: List[str]) -> NDArray:
133
- pass
134
 
 
 
 
 
 
 
 
135
 
136
- class USE(Encoder):
137
- def __init__(self):
138
- pass
139
 
140
- def encode(self, prediction: List[str]) -> NDArray:
141
- pass
142
 
 
 
 
143
 
144
- class SBertEncoder(Encoder):
145
- def __init__(self, model_name: str, device: Union[str, int], batch_size: int):
146
- self.model = SentenceTransformer(model_name)
147
- self.device = device
148
- self.batch_size = batch_size
149
 
150
- def encode(self, prediction: List[str]) -> NDArray:
151
- """Returns sentence embeddings of dim: Batch x Dim"""
152
- # SBert output is always Batch x Dim
153
- return self.model.encode(prediction, device=self.device, batch_size=self.batch_size)
154
 
 
 
 
 
 
 
 
 
155
 
156
- def _get_encoder(model_name: str, device: Union[str, int], batch_size: int) -> Encoder:
157
- if model_name == "use":
158
- return SBertEncoder(model_name, device, batch_size)
159
- # return USE() # TODO: This will change depending on PyTorch USE VS TF USE model
160
- else:
161
- return SBertEncoder(model_name, device, batch_size)
162
 
 
 
163
 
164
- def _compute_cosine_similarity(pred_embeds: NDArray, ref_embeds: NDArray) -> Tuple[float, float]:
165
- cosine_scores = cosine_similarity(pred_embeds, ref_embeds)
166
- precision_per_sentence_sim = np.max(cosine_scores, axis=-1)
167
- recall_per_sentence_sim = np.max(cosine_scores, axis=0)
168
- return np.mean(precision_per_sentence_sim).item(), np.mean(recall_per_sentence_sim).item()
169
 
 
 
 
170
 
171
- def _get_gpu(gpu: Union[bool, int]) -> Union[str, int]:
172
- # Ensure gpu index is within the range of total available gpus
173
- gpu_available = torch.cuda.is_available()
174
- if gpu_available:
175
- gpu_count = torch.cuda.device_count()
176
- if isinstance(gpu, int) and gpu >= gpu_count:
177
- raise ValueError(
178
- f"There are {gpu_count} gpus available. Provide the correct gpu index. You provided: {gpu}"
179
- )
180
 
181
- # get the device
182
- if gpu is False:
183
- device = "cpu"
184
- elif gpu is True and gpu_available:
185
- device = 0 # by default run on device 0
186
- elif isinstance(gpu, int):
187
- device = gpu
188
- else: # This will never happen
189
- raise ValueError(f"gpu must be bool or int. Provided value: {gpu}")
190
 
191
- return device
 
192
 
 
 
 
193
 
194
- def _validate_input_format(
195
- tokenize_sentences: bool,
196
- multi_references: bool,
197
- predictions: _PREDICTION_TYPE,
198
- references: _REFERENCE_TYPE,
199
- ):
200
  if tokenize_sentences and multi_references:
201
  condition = is_list_of_strings_at_depth(predictions, 1) and is_list_of_strings_at_depth(references, 2)
202
  elif not tokenize_sentences and multi_references:
@@ -215,7 +213,7 @@ class SemF1(evaluate.Metric):
215
  _MODEL_TYPE_TO_NAME = {
216
  "pv1": "paraphrase-distilroberta-base-v1",
217
  "stsb": "stsb-roberta-large",
218
- "use": "sentence-transformers/use-cmlm-multilingual", # TODO: check PyTorch USE VS TF USE
219
  }
220
 
221
  def _info(self):
@@ -275,7 +273,7 @@ class SemF1(evaluate.Metric):
275
 
276
  def _get_model_name(self, model_type: Optional[str] = None) -> str:
277
  if model_type is None:
278
- model_type = "pv1" # TODO: Change it to use
279
 
280
  if model_type not in self._MODEL_TYPE_TO_NAME.keys():
281
  raise ValueError(f"Provide a correct model_type.\n"
@@ -291,7 +289,6 @@ class SemF1(evaluate.Metric):
291
  # if not nltk.data.find("tokenizers/punkt"): # TODO: check why it is not working
292
  # pass
293
 
294
-
295
  def _compute(
296
  self,
297
  predictions,
@@ -299,8 +296,9 @@ class SemF1(evaluate.Metric):
299
  model_type: Optional[str] = None,
300
  tokenize_sentences: bool = True,
301
  multi_references: bool = False,
302
- gpu: Union[bool, int] = False,
303
  batch_size: int = 32,
 
304
  ) -> List[Scores]:
305
  """
306
  Compute precision, recall, and F1 scores for given predictions and references.
@@ -308,10 +306,15 @@ class SemF1(evaluate.Metric):
308
  :param predictions
309
  :param references
310
  :param model_type: Type of model to use for encoding.
 
 
 
 
311
  :param tokenize_sentences: Flag to sentence tokenize the document.
312
  :param multi_references: Flag to indicate multiple references.
313
  :param gpu: GPU device to use.
314
  :param batch_size: Batch size for encoding.
 
315
 
316
  :return: List of Scores dataclass with precision, recall, and F1 scores.
317
  """
@@ -320,11 +323,13 @@ class SemF1(evaluate.Metric):
320
  _validate_input_format(tokenize_sentences, multi_references, predictions, references)
321
 
322
  # Get GPU
323
- device = _get_gpu(gpu)
 
 
324
 
325
  # Get the encoder model
326
  model_name = self._get_model_name(model_type)
327
- encoder = _get_encoder(model_name, device=device, batch_size=batch_size)
328
 
329
  # We'll handle the single reference and multi-reference case same way. So change the data format accordingly
330
  if not multi_references:
 
14
  # TODO: Add test cases, Remove tokenize_sentences flag since it can be determined from the input itself.
15
  """Sem-F1 metric"""
16
 
17
+ from functools import partial
18
+ from typing import List, Optional, Tuple
 
19
 
20
  import datasets
21
  import evaluate
22
  import nltk
23
  import numpy as np
24
  from numpy.typing import NDArray
 
25
  from sklearn.metrics.pairwise import cosine_similarity
 
 
26
 
27
+ from encoder_models import get_encoder
28
+ from type_aliases import DEVICE_TYPE, PREDICTION_TYPE, REFERENCE_TYPE
29
+ from utils import is_nested_list_of_type, Scores, slice_embeddings, flatten_list, get_gpu
30
 
31
  _CITATION = """\
32
  @inproceedings{bansal-etal-2022-sem,
 
121
  [0.77, 0.56]
122
  """
123
 
 
 
124
 
125
+ def _compute_cosine_similarity(pred_embeds: NDArray, ref_embeds: NDArray) -> Tuple[float, float]:
126
+ """
127
+ Compute precision and recall based on cosine similarity between predicted and reference embeddings.
128
 
129
+ Args:
130
+ pred_embeds (NDArray): Predicted embeddings (shape: [num_pred, embedding_dim]).
131
+ ref_embeds (NDArray): Reference embeddings (shape: [num_ref, embedding_dim]).
 
132
 
133
+ Returns:
134
+ Tuple[float, float]: Precision and recall based on cosine similarity scores.
135
+ Precision: Average maximum cosine similarity score per predicted embedding.
136
+ Recall: Average maximum cosine similarity score per reference embedding.
137
+ """
138
+ # Compute cosine similarity between predicted and reference embeddings
139
+ cosine_scores = cosine_similarity(pred_embeds, ref_embeds)
140
 
141
+ # Compute precision per predicted embedding
142
+ precision_per_sentence_sim = np.max(cosine_scores, axis=-1)
 
143
 
144
+ # Compute recall per reference embedding
145
+ recall_per_sentence_sim = np.max(cosine_scores, axis=0)
146
 
147
+ # Calculate mean precision and recall scores
148
+ precision = np.mean(precision_per_sentence_sim).item()
149
+ recall = np.mean(recall_per_sentence_sim).item()
150
 
151
+ return precision, recall
 
 
 
 
152
 
 
 
 
 
153
 
154
+ def _validate_input_format(
155
+ tokenize_sentences: bool,
156
+ multi_references: bool,
157
+ predictions: PREDICTION_TYPE,
158
+ references: REFERENCE_TYPE,
159
+ ):
160
+ """
161
+ Validate the format of predictions and references based on specified criteria.
162
 
163
+ Args:
164
+ - tokenize_sentences (bool): Flag indicating whether sentences should be tokenized.
165
+ - multi_references (bool): Flag indicating whether multiple references are provided.
166
+ - predictions (PREDICTION_TYPE): Predictions to validate.
167
+ - references (REFERENCE_TYPE): References to validate.
 
168
 
169
+ Raises:
170
+ - ValueError: If the format of predictions or references does not meet the specified criteria.
171
 
172
+ Validation Criteria:
173
+ The function validates predictions and references based on the following conditions:
174
+ 1. If `tokenize_sentences` is True and `multi_references` is True:
175
+ - Predictions must be a list of strings (`is_list_of_strings_at_depth(predictions, 1)`).
176
+ - References must be a list of list of strings (`is_list_of_strings_at_depth(references, 2)`).
177
 
178
+ 2. If `tokenize_sentences` is False and `multi_references` is True:
179
+ - Predictions must be a list of list of strings (`is_list_of_strings_at_depth(predictions, 2)`).
180
+ - References must be a list of list of list of strings (`is_list_of_strings_at_depth(references, 3)`).
181
 
182
+ 3. If `tokenize_sentences` is True and `multi_references` is False:
183
+ - Predictions must be a list of strings (`is_list_of_strings_at_depth(predictions, 1)`).
184
+ - References must be a list of strings (`is_list_of_strings_at_depth(references, 1)`).
 
 
 
 
 
 
185
 
186
+ 4. If `tokenize_sentences` is False and `multi_references` is False:
187
+ - Predictions must be a list of list of strings (`is_list_of_strings_at_depth(predictions, 2)`).
188
+ - References must be a list of list of strings (`is_list_of_strings_at_depth(references, 2)`).
 
 
 
 
 
 
189
 
190
+ The function checks these conditions and raises a ValueError if any condition is not met,
191
+ indicating that predictions or references are not in the valid input format.
192
 
193
+ Note:
194
+ - `PREDICTION_TYPE` and `REFERENCE_TYPE` are defined at the top of the file
195
+ """
196
 
197
+ is_list_of_strings_at_depth = partial(is_nested_list_of_type, element_type=str)
 
 
 
 
 
198
  if tokenize_sentences and multi_references:
199
  condition = is_list_of_strings_at_depth(predictions, 1) and is_list_of_strings_at_depth(references, 2)
200
  elif not tokenize_sentences and multi_references:
 
213
  _MODEL_TYPE_TO_NAME = {
214
  "pv1": "paraphrase-distilroberta-base-v1",
215
  "stsb": "stsb-roberta-large",
216
+ "use": "use", # "sentence-transformers/use-cmlm-multilingual", # TODO: check PyTorch USE VS TF USE
217
  }
218
 
219
  def _info(self):
 
273
 
274
  def _get_model_name(self, model_type: Optional[str] = None) -> str:
275
  if model_type is None:
276
+ model_type = "use"
277
 
278
  if model_type not in self._MODEL_TYPE_TO_NAME.keys():
279
  raise ValueError(f"Provide a correct model_type.\n"
 
289
  # if not nltk.data.find("tokenizers/punkt"): # TODO: check why it is not working
290
  # pass
291
 
 
292
  def _compute(
293
  self,
294
  predictions,
 
296
  model_type: Optional[str] = None,
297
  tokenize_sentences: bool = True,
298
  multi_references: bool = False,
299
+ gpu: DEVICE_TYPE = False,
300
  batch_size: int = 32,
301
+ verbose: bool = False,
302
  ) -> List[Scores]:
303
  """
304
  Compute precision, recall, and F1 scores for given predictions and references.
 
306
  :param predictions
307
  :param references
308
  :param model_type: Type of model to use for encoding.
309
+ Options: [pv1, stsb, use]
310
+ pv1 - paraphrase-distilroberta-base-v1 (Default)
311
+ stsb - stsb-roberta-large
312
+ use - Universal Sentence Encoder
313
  :param tokenize_sentences: Flag to sentence tokenize the document.
314
  :param multi_references: Flag to indicate multiple references.
315
  :param gpu: GPU device to use.
316
  :param batch_size: Batch size for encoding.
317
+ :param verbose: Flag to indicate verbose output.
318
 
319
  :return: List of Scores dataclass with precision, recall, and F1 scores.
320
  """
 
323
  _validate_input_format(tokenize_sentences, multi_references, predictions, references)
324
 
325
  # Get GPU
326
+ device = get_gpu(gpu)
327
+ if verbose:
328
+ print(f"Using devices: {device}")
329
 
330
  # Get the encoder model
331
  model_name = self._get_model_name(model_type)
332
+ encoder = get_encoder(model_name, device=device, batch_size=batch_size, verbose=verbose)
333
 
334
  # We'll handle the single reference and multi-reference case same way. So change the data format accordingly
335
  if not multi_references:
tests.py CHANGED
@@ -1,17 +1,179 @@
1
- test_cases = [
2
- {
3
- "predictions": [0, 0],
4
- "references": [1, 1],
5
- "result": {"metric_score": 0}
6
- },
7
- {
8
- "predictions": [1, 1],
9
- "references": [1, 1],
10
- "result": {"metric_score": 1}
11
- },
12
- {
13
- "predictions": [1, 0],
14
- "references": [1, 1],
15
- "result": {"metric_score": 0.5}
16
- }
17
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import statistics
2
+ import unittest
3
+
4
+ import numpy as np
5
+ import torch
6
+ from sentence_transformers import SentenceTransformer
7
+
8
+ from encoder_models import SBertEncoder, get_encoder
9
+ from utils import get_gpu, slice_embeddings, is_nested_list_of_type, flatten_list, compute_f1, Scores
10
+
11
+
12
+ class TestUtils(unittest.TestCase):
13
+ def test_get_gpu(self):
14
+ gpu_count = torch.cuda.device_count()
15
+ gpu_available = torch.cuda.is_available()
16
+
17
+ # Test single boolean input
18
+ self.assertEqual(get_gpu(True), 0 if gpu_available else "cpu")
19
+ self.assertEqual(get_gpu(False), "cpu")
20
+
21
+ # Test single string input
22
+ self.assertEqual(get_gpu("cpu"), "cpu")
23
+ self.assertEqual(get_gpu("gpu"), 0 if gpu_available else "cpu")
24
+ self.assertEqual(get_gpu("cuda"), 0 if gpu_available else "cpu")
25
+
26
+ # Test single integer input
27
+ self.assertEqual(get_gpu(0), 0 if gpu_available else "cpu")
28
+ self.assertEqual(get_gpu(1), 1 if gpu_available else "cpu")
29
+
30
+ # Test list input with unique elements
31
+ self.assertEqual(get_gpu([True, "cpu", 0]), [0, "cpu"] if gpu_available else ["cpu", "cpu", "cpu"])
32
+
33
+ # Test list input with duplicate elements
34
+ self.assertEqual(get_gpu([0, 0, "gpu"]), [0] if gpu_available else ["cpu", "cpu", "cpu"])
35
+
36
+ # Test list input with duplicate elements of different types
37
+ self.assertEqual(get_gpu([True, 0, "gpu"]), [0] if gpu_available else ["cpu", "cpu", "cpu"])
38
+
39
+ # Test list input with all integers
40
+ self.assertEqual(get_gpu(list(range(gpu_count))),
41
+ list(range(gpu_count)) if gpu_available else gpu_count * ["cpu"])
42
+
43
+ with self.assertRaises(ValueError):
44
+ get_gpu("invalid")
45
+
46
+ with self.assertRaises(ValueError):
47
+ get_gpu(torch.cuda.device_count())
48
+
49
+ def test_slice_embeddings(self):
50
+ embeddings = np.random.rand(10, 5)
51
+ num_sentences = [3, 2, 5]
52
+ expected_output = [embeddings[:3], embeddings[3:5], embeddings[5:]]
53
+ self.assertTrue(
54
+ all(np.array_equal(a, b) for a, b in zip(slice_embeddings(embeddings, num_sentences),
55
+ expected_output))
56
+ )
57
+
58
+ num_sentences_nested = [[2, 1], [3, 4]]
59
+ expected_output_nested = [[embeddings[:2], embeddings[2:3]], [embeddings[3:6], embeddings[6:]]]
60
+ self.assertTrue(
61
+ slice_embeddings(embeddings, num_sentences_nested), expected_output_nested
62
+ )
63
+
64
+ with self.assertRaises(TypeError):
65
+ slice_embeddings(embeddings, "invalid")
66
+
67
+ def test_is_nested_list_of_type(self):
68
+ # Test case: Depth 0, single element matching element_type
69
+ self.assertTrue(is_nested_list_of_type("test", str, 0))
70
+
71
+ # Test case: Depth 0, single element not matching element_type
72
+ self.assertFalse(is_nested_list_of_type("test", int, 0))
73
+
74
+ # Test case: Depth 1, list of elements matching element_type
75
+ self.assertTrue(is_nested_list_of_type(["apple", "banana"], str, 1))
76
+
77
+ # Test case: Depth 1, list of elements not matching element_type
78
+ self.assertFalse(is_nested_list_of_type([1, 2, 3], str, 1))
79
+
80
+ # Test case: Depth 0 (Wrong), list of elements matching element_type
81
+ self.assertFalse(is_nested_list_of_type([1, 2, 3], str, 0))
82
+
83
+ # Depth 2
84
+ self.assertTrue(is_nested_list_of_type([[1, 2], [3, 4]], int, 2))
85
+ self.assertTrue(is_nested_list_of_type([['1', '2'], ['3', '4']], str, 2))
86
+ self.assertFalse(is_nested_list_of_type([[1, 2], ["a", "b"]], int, 2))
87
+
88
+ # Depth 3
89
+ self.assertFalse(is_nested_list_of_type([[[1], [2]], [[3], [4]]], list, 3))
90
+ self.assertTrue(is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3))
91
+
92
+ with self.assertRaises(ValueError):
93
+ is_nested_list_of_type([1, 2], int, -1)
94
+
95
+ def test_flatten_list(self):
96
+ self.assertEqual(flatten_list([1, [2, 3], [[4], 5]]), [1, 2, 3, 4, 5])
97
+ self.assertEqual(flatten_list([]), [])
98
+ self.assertEqual(flatten_list([1, 2, 3]), [1, 2, 3])
99
+ self.assertEqual(flatten_list([[[[1]]]]), [1])
100
+
101
+ def test_compute_f1(self):
102
+ self.assertAlmostEqual(compute_f1(0.5, 0.5), 0.5)
103
+ self.assertAlmostEqual(compute_f1(1, 0), 0.0)
104
+ self.assertAlmostEqual(compute_f1(0, 1), 0.0)
105
+ self.assertAlmostEqual(compute_f1(1, 1), 1.0)
106
+
107
+ def test_scores(self):
108
+ scores = Scores(precision=0.8, recall=[0.7, 0.9])
109
+ self.assertAlmostEqual(scores.f1, compute_f1(0.8, statistics.fmean([0.7, 0.9])))
110
+
111
+
112
+ class TestSBertEncoder(unittest.TestCase):
113
+ def setUp(self, device=None):
114
+ if device is None:
115
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
116
+ else:
117
+ self.device = device
118
+ self.model_name = "stsb-roberta-large"
119
+ self.batch_size = 8
120
+ self.verbose = False
121
+ self.encoder = SBertEncoder(self.model_name, self.device, self.batch_size, self.verbose)
122
+
123
+ def test_initialization(self):
124
+ self.assertIsInstance(self.encoder.model, SentenceTransformer)
125
+ self.assertEqual(self.encoder.device, self.device)
126
+ self.assertEqual(self.encoder.batch_size, self.batch_size)
127
+ self.assertEqual(self.encoder.verbose, self.verbose)
128
+
129
+ def test_encode_single_device(self):
130
+ sentences = ["This is a test sentence.", "Here is another sentence."]
131
+ embeddings = self.encoder.encode(sentences)
132
+ self.assertIsInstance(embeddings, np.ndarray)
133
+ self.assertEqual(embeddings.shape[0], len(sentences))
134
+ self.assertEqual(embeddings.shape[1], self.encoder.model.get_sentence_embedding_dimension())
135
+
136
+ def test_encode_multi_device(self):
137
+ if torch.cuda.device_count() < 2:
138
+ self.skipTest("Multi-GPU test requires at least 2 GPUs.")
139
+ else:
140
+ devices = ["cuda:0", "cuda:1"]
141
+ self.setUp(devices)
142
+ sentences = ["This is a test sentence.", "Here is another sentence.", "This is a test sentence."]
143
+ embeddings = self.encoder.encode(sentences)
144
+ self.assertIsInstance(embeddings, np.ndarray)
145
+ self.assertEqual(embeddings.shape[0], 3)
146
+ self.assertEqual(embeddings.shape[1], self.encoder.model.get_sentence_embedding_dimension())
147
+
148
+
149
+ class TestGetEncoder(unittest.TestCase):
150
+ def test_get_sbert_encoder(self):
151
+ model_name = "stsb-roberta-large"
152
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
153
+ batch_size = 8
154
+ verbose = False
155
+
156
+ encoder = get_encoder(model_name, device, batch_size, verbose)
157
+ self.assertIsInstance(encoder, SBertEncoder)
158
+ self.assertEqual(encoder.device, device)
159
+ self.assertEqual(encoder.batch_size, batch_size)
160
+ self.assertEqual(encoder.verbose, verbose)
161
+
162
+ def test_get_use_encoder(self):
163
+ model_name = "use"
164
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
165
+ batch_size = 8
166
+ verbose = False
167
+
168
+ encoder = get_encoder(model_name, device, batch_size, verbose)
169
+ self.assertIsInstance(encoder, SBertEncoder) # SBertEncoder is returned for "use" for now
170
+ # Uncomment below when implementing USE class
171
+ # self.assertIsInstance(encoder, USE)
172
+ # self.assertEqual(encoder.model_name, model_name)
173
+ # self.assertEqual(encoder.device, device)
174
+ # self.assertEqual(encoder.batch_size, batch_size)
175
+ # self.assertEqual(encoder.verbose, verbose)
176
+
177
+
178
+ if __name__ == '__main__':
179
+ unittest.main()
type_aliases.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Union
2
+
3
+ from numpy.typing import NDArray
4
+
5
+ NumSentencesType = Union[List[int], List[List[int]]]
6
+ EmbeddingSlicesType = Union[List[NDArray], List[List[NDArray]]]
7
+ PREDICTION_TYPE = Union[List[str], List[List[str]]]
8
+ REFERENCE_TYPE = Union[List[str], List[List[str]], List[List[List[str]]]]
9
+ DEVICE_TYPE = Union[bool, str, int, List[Union[str, int]]]
10
+ ENCODER_DEVICE_TYPE = Union[str, int, List[Union[str, int]]]
utils.py CHANGED
@@ -1,13 +1,81 @@
1
- from dataclasses import dataclass
2
  import statistics
3
  import sys
 
4
  from typing import List, Union
5
 
 
6
  from numpy.typing import NDArray
7
 
 
8
 
9
- NumSentencesType = Union[List[int], List[List[int]]]
10
- EmbeddingSlicesType = Union[List[NDArray], List[List[NDArray]]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
 
13
  def slice_embeddings(embeddings: NDArray, num_sentences: NumSentencesType) -> EmbeddingSlicesType:
@@ -22,10 +90,10 @@ def slice_embeddings(embeddings: NDArray, num_sentences: NumSentencesType) -> Em
22
  result, _ = _slice_embeddings(0, num_sentences)
23
  return result
24
  elif isinstance(num_sentences, list) and all(
25
- isinstance(sublist, list) and all(
26
- isinstance(item, int) for item in sublist
27
- )
28
- for sublist in num_sentences
29
  ):
30
  nested_result = []
31
  start_idx = 0
@@ -38,11 +106,11 @@ def slice_embeddings(embeddings: NDArray, num_sentences: NumSentencesType) -> Em
38
  raise TypeError(f"Incorrect Type for {num_sentences=}")
39
 
40
 
41
- def is_list_of_strings_at_depth(obj, depth: int) -> bool:
42
  if depth == 0:
43
- return isinstance(obj, str)
44
  elif depth > 0:
45
- return isinstance(obj, list) and all(is_list_of_strings_at_depth(item, depth - 1) for item in obj)
46
  else:
47
  raise ValueError("Depth can't be negative")
48
 
 
 
1
  import statistics
2
  import sys
3
+ from dataclasses import dataclass
4
  from typing import List, Union
5
 
6
+ import torch
7
  from numpy.typing import NDArray
8
 
9
+ from type_aliases import DEVICE_TYPE, ENCODER_DEVICE_TYPE, NumSentencesType, EmbeddingSlicesType
10
 
11
+
12
+ def get_gpu(gpu: DEVICE_TYPE) -> ENCODER_DEVICE_TYPE:
13
+ """
14
+ Determine the correct GPU device based on the provided input. In the following, output 0 means CUDA device 0.
15
+
16
+ Args:
17
+ gpu (Union[bool, str, int, List[Union[str, int]]]): Input specifying the GPU device(s):
18
+ - bool: If True, returns 0 if CUDA is available, otherwise returns "cpu".
19
+ - str: Can be "cpu", "gpu", or "cuda" (case-insensitive). Returns 0 if CUDA is available
20
+ and the input is not "cpu", otherwise returns "cpu".
21
+ - int: Should be a valid GPU index. Returns the index if CUDA is available and valid,
22
+ otherwise returns "cpu".
23
+ - List[Union[str, int]]: List containing combinations of the str/int. Processes each
24
+ element and returns a list of corresponding results.
25
+
26
+ Returns:
27
+ Union[str, int, List[Union[str, int]]]: Depending on the input type:
28
+ - str: Returns "cpu" if no GPU is available or the input is "cpu".
29
+ - int: Returns the GPU index if valid and CUDA is available.
30
+ - List[Union[str, int]]: Returns a list of strings and/or integers based on the input list.
31
+
32
+ Raises:
33
+ ValueError: If the input gpu type is not recognized or invalid.
34
+ ValueError: If a string input is not one of ["cpu", "gpu", "cuda"].
35
+ ValueError: If an integer input is outside the valid range of GPU indices.
36
+
37
+ Notes:
38
+ - This function checks CUDA availability using torch.cuda.is_available() and counts
39
+ available GPUs using torch.cuda.device_count().
40
+ - Case insensitivity is maintained for string inputs ("cpu", "gpu", "cuda").
41
+ - The function ensures robust error handling for invalid input types or out-of-range indices.
42
+ """
43
+
44
+ # Ensure gpu index is within the range of total available gpus
45
+ gpu_available = torch.cuda.is_available()
46
+ gpu_count = torch.cuda.device_count()
47
+ correct_strs = ["cpu", "gpu", "cuda"]
48
+
49
+ def _get_single_device(gpu_item):
50
+ if isinstance(gpu_item, bool):
51
+ return 0 if gpu_item and gpu_available else "cpu"
52
+ elif isinstance(gpu_item, str):
53
+ if gpu_item.lower() not in correct_strs:
54
+ raise ValueError(f"Wrong gpu type: {gpu_item}. Should be one of {correct_strs}")
55
+ return 0 if (gpu_item.lower() != "cpu") and gpu_available else "cpu"
56
+ elif isinstance(gpu_item, int):
57
+ if gpu_item >= gpu_count:
58
+ raise ValueError(
59
+ f"There are {gpu_count} GPUs available. Provide a valid GPU index. You provided: {gpu_item}"
60
+ )
61
+ return gpu_item if gpu_available else "cpu"
62
+ else:
63
+ raise ValueError(f"Invalid gpu type: {type(gpu_item)}. Must be bool, str, or int.")
64
+
65
+ if isinstance(gpu, list):
66
+ seen_indices = set()
67
+ result = []
68
+ for item in gpu:
69
+ device = _get_single_device(item)
70
+ if isinstance(device, int):
71
+ if device not in seen_indices:
72
+ seen_indices.add(device)
73
+ result.append(device)
74
+ else:
75
+ result.append(device)
76
+ return result
77
+ else:
78
+ return _get_single_device(gpu)
79
 
80
 
81
  def slice_embeddings(embeddings: NDArray, num_sentences: NumSentencesType) -> EmbeddingSlicesType:
 
90
  result, _ = _slice_embeddings(0, num_sentences)
91
  return result
92
  elif isinstance(num_sentences, list) and all(
93
+ isinstance(sublist, list) and all(
94
+ isinstance(item, int) for item in sublist
95
+ )
96
+ for sublist in num_sentences
97
  ):
98
  nested_result = []
99
  start_idx = 0
 
106
  raise TypeError(f"Incorrect Type for {num_sentences=}")
107
 
108
 
109
+ def is_nested_list_of_type(lst_obj, element_type, depth: int) -> bool:
110
  if depth == 0:
111
+ return isinstance(lst_obj, element_type)
112
  elif depth > 0:
113
+ return isinstance(lst_obj, list) and all(is_nested_list_of_type(item, element_type, depth - 1) for item in lst_obj)
114
  else:
115
  raise ValueError("Depth can't be negative")
116