nbansal commited on
Commit
a54024a
1 Parent(s): 668c6f3

Added SemNCG metric

Browse files
Files changed (9) hide show
  1. .gitignore +1 -0
  2. README.md +85 -20
  3. __init__.py +0 -0
  4. encoder_models.py +129 -0
  5. requirements.txt +3 -1
  6. semncg.py +475 -45
  7. tests.py +418 -17
  8. type_aliases.py +11 -0
  9. utils.py +280 -0
.gitignore ADDED
@@ -0,0 +1 @@
1
+ __pycache__/
README.md CHANGED
@@ -5,46 +5,111 @@ datasets:
5
  tags:
6
  - evaluate
7
  - metric
8
- description: "TODO: add a description here"
9
  sdk: gradio
10
  sdk_version: 3.19.1
11
  app_file: app.py
12
  pinned: false
13
  ---
14
 
15
- # Metric Card for SemnCG
16
-
17
- ***Module Card Instructions:*** *Fill out the following subsections. Feel free to take a look at existing metric cards if you'd like examples.*
18
 
19
  ## Metric Description
20
- *Give a brief overview of this metric, including what task(s) it is usually used for, if any.*
21
 
22
  ## How to Use
23
- *Give general statement of how to use the metric*
24
 
25
- *Provide simplest possible example for using the metric*
26
 
27
- ### Inputs
28
- *List all input arguments in the format below*
29
- - **input_field** *(type): Definition of input, with explanation if necessary. State any default value(s).*
30
 
31
  ### Output Values
32
 
33
- *Explain what this metric outputs and provide an example of what the metric output looks like. Modules should return a dictionary with one or multiple key-value pairs, e.g. {"bleu" : 6.02}*
34
 
35
- *State the range of possible values that the metric's output can take, as well as what in that range is considered good. For example: "This metric can take on any value between 0 and 100, inclusive. Higher scores are better."*
 
36
 
37
- #### Values from Popular Papers
38
- *Give examples, preferrably with links to leaderboards or publications, to papers that have reported this metric, along with the values they have reported.*
39
 
40
- ### Examples
41
- *Give code examples of the metric being used. Try to include examples that clear up any potential ambiguity left from the metric description above. If possible, provide a range of examples that show both typical and atypical results, as well as examples where a variety of input parameters are passed.*
42
 
43
- ## Limitations and Bias
44
- *Note any known limitations or biases that the metric has, with links and references if possible.*
45
 
46
  ## Citation
47
- *Cite the source where this metric was introduced.*
48
 
49
  ## Further References
50
- *Add any useful further references.*
5
  tags:
6
  - evaluate
7
  - metric
8
+ description: "Sem-nCG (Semantic Normalized Cumulative Gain) Metric evaluates the quality of predicted sentences
9
+ (abstractive/extractive) in relation to reference sentences and documents using Semantic Normalized Cumulative Gain
10
+ (NCG). It computes gain values and NCG scores based on cosine similarity between sentence embeddings, leveraging a
11
+ Sentence-BERT encoder. This metric is designed to assess the relevance and ranking of predicted sentences, making it
12
+ useful for tasks such as summarization and information retrieval."
13
  sdk: gradio
14
  sdk_version: 3.19.1
15
  app_file: app.py
16
  pinned: false
17
  ---
18
 
19
+ # Metric Card for Sem-nCG
20
 
21
  ## Metric Description
22
+ The Sem-nCG (Semantic Normalized Cumulative Gain) metric evaluates system-generated summaries (`predictions`) by comparing
23
+ them with ground truth reference summaries (`references`) and input documents (`documents`). It computes the Semantic
24
+ Normalized Cumulative Gain (NCG) scores based on sentence embeddings, which assess the quality of summaries by
25
+ evaluating the relevance of predicted sentences to the reference and input document sentences.
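+ Concretely, the implementation ranks the sentences of each input document by their mean cosine similarity to the
+ reference sentences (the ideal ranking) and to the predicted sentences (the model ranking), converts the ideal ranking
+ into per-sentence gain values, and then scores a prediction as
+
+ $$\text{Sem-nCG@}k = \frac{\sum_{s \in \text{top-}k \text{ (model ranking)}} g(s)}{\sum_{s \in \text{top-}k \text{ (ideal ranking)}} g(s)}$$
+
+ where $g(s)$ is the gain assigned to document sentence $s$ under the ideal ranking.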
26
+
27
 
28
  ## How to Use
 
29
 
30
+ Before using this metric, you need to install the dependencies:
31
+ ```bash
32
+ pip install -U sentence-transformers nltk
33
+ ```
34
+
35
+ Sem-nCG takes three mandatory arguments:
36
+ - `predictions` - List of predictions
37
+ - `references` - List of references
38
+ - `documents` - List of input documents
39
+
40
+ ```python
41
+ from evaluate import load
42
+ predictions = [
43
+ "This is a prediction1 sentence 1. This is a prediction1 sentence 2.",
44
+ "This is a prediction2 sentence 1."
45
+ ]
46
+ references = [
47
+ "This is a reference1 sentence 1. This is a reference1 sentence 2.",
48
+ "This is a reference2 sentence 1. This is a reference2 sentence 2."
49
+ ]
50
+ documents = [
51
+ "This is a document1 sentence 1. This is a document1 sentence 2. This is a document1 sentence 3.",
52
+ "This is a document2 sentence 1. This is a document2 sentence 2."
53
+ ]
54
+ model_name = "all-MiniLM-L6-v2"
55
+ metric = load("nbansal/semncg", model_name=model_name) # model_name is optional. Default=all-MiniLM-L6-v2
56
+ mean_score, scores = metric.compute(predictions=predictions, references=references, documents=documents)
57
+ print(f"Mean SemnCG: {mean_score}")
58
+ ```
59
+
60
+ Sem-nCG also accepts several optional arguments (a usage sketch follows this list):
+ - `tokenize_sentences (bool)`: Whether to sentence-tokenize the inputs. Default is True.
+ - `pre_compute_embeddings (bool)`: Whether to pre-compute embeddings for all sentences at once (faster, but uses more memory). Default is False.
+ - `k (int)`: The rank threshold used for evaluating gains (typically top-k sentences). Default is 3.
+ - `gpu (Union[bool, str, int, List[Union[str, int]]])`: Whether to use GPU, CPU, or multiple processes for computation. Default is False (CPU).
+ - `batch_size (int)`: Batch size for encoding. Default is 32.
+ - `verbose (bool)`: Whether to print verbose output. Default is False.
+ - `debug (bool)`: Whether to return detailed debug information, including ranked gains. Default is False.
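+
+ For example, here is a minimal sketch with illustrative toy inputs that passes pre-tokenized sentences and requests
+ debug output (it reuses the `metric` object loaded above):
+ ```python
+ predictions = [["This is a prediction1 sentence 1.", "This is a prediction1 sentence 2."]]
+ references = [["This is a reference1 sentence 1.", "This is a reference1 sentence 2."]]
+ documents = [["This is a document1 sentence 1.", "This is a document1 sentence 2."]]
+ mean_score, ranked_gains = metric.compute(
+     predictions=predictions,
+     references=references,
+     documents=documents,
+     tokenize_sentences=False,  # inputs are already split into sentences
+     k=2,
+     debug=True,  # each entry of ranked_gains is a RankedGains object
+ )
+ print(f"Mean SemnCG: {mean_score}")
+ ```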
68
 
69
+ For more detailed usage information, refer to the metric's inputs description:
70
+ ```python
71
+ import evaluate
72
+ metric = evaluate.load("nbansal/semncg")
73
+ print(metric.inputs_description)
74
+ ```
75
 
76
  ### Output Values
77
 
78
+ The output is a tuple containing:
+ - `mean_score` *(float)*: The average Sem-nCG score over all input documents. Scores range from 0 to 1; higher is better.
+ - `scores` *(List[Union[float, RankedGains]])*: The per-document Sem-nCG scores, or `RankedGains` objects when `debug=True`.
82
83
 
84
+ ## Extensions
85
+ The current implementation supports any Hugging Face or Sentence-Transformers model that can be loaded with the
+ `SentenceTransformer` class, such as `all-mpnet-base-v2` or `roberta-base`. To support additional encoders, extend the
+ `Encoder` base class in `encoder_models.py`, as sketched below.
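+
+ A minimal sketch of such an extension (the `MyEncoder` class and its `embed` call are purely illustrative and not part
+ of this module):
+ ```python
+ from typing import List
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from encoder_models import Encoder  # exact import path depends on how the module is installed/loaded
+
+
+ class MyEncoder(Encoder):
+     def __init__(self, model):
+         self.model = model  # any embedding model you want to wrap
+
+     def encode(self, prediction: List[str]) -> NDArray:
+         # Must return an array of shape (num_sentences, embedding_dim)
+         return np.asarray(self.model.embed(prediction))
+ ```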
88
 
89
+ ## Deviations from Published Methodology
 
90
 
91
  ## Citation
92
+ ```bibtex
93
+ @inproceedings{akter-etal-2022-revisiting,
94
+ title = "Revisiting Automatic Evaluation of Extractive Summarization Task: Can We Do Better than {ROUGE}?",
95
+ author = "Akter, Mousumi and
96
+ Bansal, Naman and
97
+ Karmaker, Shubhra Kanti",
98
+ editor = "Muresan, Smaranda and
99
+ Nakov, Preslav and
100
+ Villavicencio, Aline",
101
+ booktitle = "Findings of the Association for Computational Linguistics: ACL 2022",
102
+ month = may,
103
+ year = "2022",
104
+ address = "Dublin, Ireland",
105
+ publisher = "Association for Computational Linguistics",
106
+ url = "https://aclanthology.org/2022.findings-acl.122",
107
+ doi = "10.18653/v1/2022.findings-acl.122",
108
+ pages = "1547--1560",
109
+ abstract = "It has been the norm for a long time to evaluate automated summarization tasks using the popular ROUGE metric. Although several studies in the past have highlighted the limitations of ROUGE, researchers have struggled to reach a consensus on a better alternative until today. One major limitation of the traditional ROUGE metric is the lack of semantic understanding (relies on direct overlap of n-grams). In this paper, we exclusively focus on the extractive summarization task and propose a semantic-aware nCG (normalized cumulative gain)-based evaluation metric (called Sem-nCG) for evaluating this task. One fundamental contribution of the paper is that it demonstrates how we can generate more reliable semantic-aware ground truths for evaluating extractive summarization tasks without any additional human intervention. To the best of our knowledge, this work is the first of its kind. We have conducted extensive experiments with this new metric using the widely used CNN/DailyMail dataset. Experimental results show that the new Sem-nCG metric is indeed semantic-aware, shows higher correlation with human judgement (more reliable) and yields a large number of disagreements with the original ROUGE metric (suggesting that ROUGE often leads to inaccurate conclusions also verified by humans).",
110
+ }
111
+ ```
112
 
113
  ## Further References
114
+ - [Paper](https://aclanthology.org/2022.findings-acl.122/)
115
+ - [Video](https://underline.io/lecture/50182-findings-revisiting-automatic-evaluation-of-extractive-summarization-task-can-we-do-better-than-rougequestion)
__init__.py ADDED
File without changes
encoder_models.py ADDED
@@ -0,0 +1,129 @@
1
+ import abc
2
+ from typing import List, Union
3
+
4
+ from numpy.typing import NDArray
5
+ from sentence_transformers import SentenceTransformer
6
+
7
+ from .type_aliases import ENCODER_DEVICE_TYPE
8
+
9
+
10
+ class Encoder(abc.ABC):
11
+ @abc.abstractmethod
12
+ def encode(self, prediction: List[str]) -> NDArray:
13
+ """
14
+ Abstract method to encode a list of sentences into sentence embeddings.
15
+
16
+ Args:
17
+ prediction (List[str]): List of sentences to encode.
18
+
19
+ Returns:
20
+ NDArray: Array of sentence embeddings with shape (num_sentences, embedding_dim).
21
+
22
+ Raises:
23
+ NotImplementedError: If the method is not implemented in the subclass.
24
+ """
25
+ raise NotImplementedError("Method 'encode' must be implemented in subclass.")
26
+
27
+
28
+ class SBertEncoder(Encoder):
29
+ def __init__(self, model: SentenceTransformer, device: ENCODER_DEVICE_TYPE, batch_size: int, verbose: bool):
30
+ """
31
+ Initialize SBertEncoder instance.
32
+
33
+ Args:
34
+ model (SentenceTransformer): The Sentence Transformer model instance to use for encoding.
35
+ device (Union[str, int, List[Union[str, int]]]): Device specification for encoding
36
+ batch_size (int): Batch size for encoding.
37
+ verbose (bool): Whether to print verbose information during encoding.
38
+ """
39
+ self.model = model
40
+ self.device = device
41
+ self.batch_size = batch_size
42
+ self.verbose = verbose
43
+
44
+ def encode(self, prediction: List[str]) -> NDArray:
45
+ """
46
+ Encode a list of sentences into sentence embeddings.
47
+
48
+ Args:
49
+ prediction (List[str]): List of sentences to encode.
50
+
51
+ Returns:
52
+ NDArray: Array of sentence embeddings with shape (num_sentences, embedding_dim).
53
+ """
54
+
55
+ # SBert output is always Batch x Dim
56
+ if isinstance(self.device, list):
57
+ # Use multiprocess encoding for list of devices
58
+ pool = self.model.start_multi_process_pool(target_devices=self.device)
59
+ embeddings = self.model.encode_multi_process(prediction, pool=pool, batch_size=self.batch_size)
60
+ self.model.stop_multi_process_pool(pool)
61
+ else:
62
+ # Single device encoding
63
+ embeddings = self.model.encode(
64
+ prediction,
65
+ device=self.device,
66
+ batch_size=self.batch_size,
67
+ show_progress_bar=self.verbose,
68
+ )
69
+
70
+ return embeddings
71
+
72
+
73
+ def get_encoder(
74
+ sbert_model: SentenceTransformer,
75
+ device: ENCODER_DEVICE_TYPE,
76
+ batch_size: int,
77
+ verbose: bool,
78
+ ) -> Encoder:
79
+ """
80
+ Get an instance of SBertEncoder using the provided parameters.
81
+
82
+ Args:
83
+ sbert_model (SentenceTransformer): An instance of SentenceTransformer model to use for encoding.
84
+ device (Union[str, int, List[Union[str, int]]]): Device specification for the encoder
85
+ (e.g., "cuda", 0 for GPU, "cpu").
86
+ batch_size (int): Batch size to use for encoding.
87
+ verbose (bool): Whether to print verbose information during encoding.
88
+
89
+ Returns:
90
+ SBertEncoder: Encoder instance wrapping the provided SentenceTransformer model.
91
+
92
+ Example:
93
+ >>> model_name = "paraphrase-distilroberta-base-v1"
94
+ >>> sbert_model = get_sbert_encoder(model_name)
95
+ >>> device = get_gpu("cuda")
96
+ >>> batch_size = 32
97
+ >>> verbose = True
98
+ >>> encoder = get_encoder(sbert_model, device, batch_size, verbose)
99
+ """
100
+ encoder = SBertEncoder(sbert_model, device, batch_size, verbose)
101
+ return encoder
102
+
103
+
104
+ def get_sbert_encoder(model_name: str) -> SentenceTransformer:
105
+ """
106
+ Get an instance of SentenceTransformer encoder based on the specified model name.
107
+
108
+ Args:
109
+ model_name (str): Name of the model to instantiate. You can use any model on Huggingface/SentenceTransformer
110
+ that is supported by SentenceTransformer.
111
+
112
+ Returns:
113
+ SentenceTransformer: Instance of the selected encoder based on the model_name.
114
+
115
+ Raises:
116
+ EnvironmentError: If an unsupported model_name is provided.
117
+ RuntimeError: If there's an issue during instantiation of the encoder.
118
+ """
119
+
120
+ try:
121
+ encoder = SentenceTransformer(model_name, trust_remote_code=True)
122
+ except EnvironmentError as err:
123
+ raise EnvironmentError(str(err)) from None
124
+ except Exception as err:
125
+ raise RuntimeError(str(err)) from None
126
+
127
+ return encoder
128
+
129
+
requirements.txt CHANGED
@@ -1 +1,3 @@
1
- git+https://github.com/huggingface/evaluate@main
1
+ git+https://github.com/huggingface/evaluate@main
2
+ nltk
3
+ sentence-transformers
semncg.py CHANGED
@@ -11,55 +11,340 @@
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
- """TODO: Add a description here."""
15
16
  import evaluate
17
  import datasets
18
19
 
20
- # TODO: Add BibTeX citation
21
  _CITATION = """\
22
- @InProceedings{huggingface:module,
23
- title = {A great new module},
24
- authors={huggingface, Inc.},
25
- year={2020}
26
  }
27
  """
28
 
29
- # TODO: Add description of the module here
30
  _DESCRIPTION = """\
31
- This new module is designed to solve this great ML task and is crafted with a lot of care.
32
  """
33
 
34
-
35
- # TODO: Add description of the arguments of the module here
36
  _KWARGS_DESCRIPTION = """
37
- Calculates how good are predictions given some references, using certain scores
38
  Args:
39
- predictions: list of predictions to score. Each predictions
40
- should be a string with tokens separated by spaces.
41
- references: list of reference for each prediction. Each
42
- reference should be a string with tokens separated by spaces.
43
  Returns:
44
- accuracy: description of the first score,
45
- another_score: description of the second score,
46
  Examples:
47
- Examples should be written in doctest format, and should illustrate how
48
- to use the function.
49
 
50
- >>> my_new_module = evaluate.load("my_new_module")
51
- >>> results = my_new_module.compute(references=[0, 1], predictions=[0, 1])
52
- >>> print(results)
53
- {'accuracy': 1.0}
 
  """
55
 
56
- # TODO: Define external resources urls if needed
57
- BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
58
 
59
 
60
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
61
  class SemnCG(evaluate.Metric):
62
- """TODO: Short description of my evaluation module."""
63
 
64
  def _info(self):
65
  # TODO: Specifies the evaluate.EvaluationModuleInfo object
@@ -70,26 +355,171 @@ class SemnCG(evaluate.Metric):
70
  citation=_CITATION,
71
  inputs_description=_KWARGS_DESCRIPTION,
72
  # This defines the format of each prediction and reference
73
- features=datasets.Features({
74
- 'predictions': datasets.Value('int64'),
75
- 'references': datasets.Value('int64'),
76
- }),
77
- # Homepage of the module for documentation
78
- homepage="http://module.homepage",
79
- # Additional links to the codebase or references
80
- codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
81
- reference_urls=["http://path.to.reference.url/new_module"]
82
  )
83
 
84
  def _download_and_prepare(self, dl_manager):
85
  """Optional: download external resources useful to compute the scores"""
86
- # TODO: Download external resources if needed
87
- pass
88
-
89
- def _compute(self, predictions, references):
90
- """Returns the scores"""
91
- # TODO: Compute the different scores of the module
92
- accuracy = sum(i == j for i, j in zip(predictions, references)) / len(predictions)
93
- return {
94
- "accuracy": accuracy,
95
- }
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
+ """Sem-NCG metric"""
15
 
16
+ from dataclasses import dataclass
17
  import evaluate
18
  import datasets
19
+ import re
20
+ import statistics
21
+ from typing import Dict, List, Tuple, Union
22
 
23
+ import nltk
24
+ import numpy as np
25
+ from sklearn.metrics.pairwise import cosine_similarity
26
+ from tqdm import tqdm
27
+
28
+ from .encoder_models import get_sbert_encoder, get_encoder
29
+ from .type_aliases import DEVICE_TYPE, NDArray, DOCUMENT_TYPE
30
+ from .utils import get_gpu, prep_sentences, flatten_list, slice_embeddings, is_nested_list_of_type, tokenize_and_prep_document
31
 
 
32
  _CITATION = """\
33
+ @inproceedings{akter-etal-2022-revisiting,
34
+ title = "Revisiting Automatic Evaluation of Extractive Summarization Task: Can We Do Better than {ROUGE}?",
35
+ author = "Akter, Mousumi and
36
+ Bansal, Naman and
37
+ Karmaker, Shubhra Kanti",
38
+ editor = "Muresan, Smaranda and
39
+ Nakov, Preslav and
40
+ Villavicencio, Aline",
41
+ booktitle = "Findings of the Association for Computational Linguistics: ACL 2022",
42
+ month = may,
43
+ year = "2022",
44
+ address = "Dublin, Ireland",
45
+ publisher = "Association for Computational Linguistics",
46
+ url = "https://aclanthology.org/2022.findings-acl.122",
47
+ doi = "10.18653/v1/2022.findings-acl.122",
48
+ pages = "1547--1560",
49
+ abstract = "It has been the norm for a long time to evaluate automated summarization tasks using the popular ROUGE metric. Although several studies in the past have highlighted the limitations of ROUGE, researchers have struggled to reach a consensus on a better alternative until today. One major limitation of the traditional ROUGE metric is the lack of semantic understanding (relies on direct overlap of n-grams). In this paper, we exclusively focus on the extractive summarization task and propose a semantic-aware nCG (normalized cumulative gain)-based evaluation metric (called Sem-nCG) for evaluating this task. One fundamental contribution of the paper is that it demonstrates how we can generate more reliable semantic-aware ground truths for evaluating extractive summarization tasks without any additional human intervention. To the best of our knowledge, this work is the first of its kind. We have conducted extensive experiments with this new metric using the widely used CNN/DailyMail dataset. Experimental results show that the new Sem-nCG metric is indeed semantic-aware, shows higher correlation with human judgement (more reliable) and yields a large number of disagreements with the original ROUGE metric (suggesting that ROUGE often leads to inaccurate conclusions also verified by humans).",
50
  }
51
  """
52
 
 
53
  _DESCRIPTION = """\
54
+ Sem-nCG (Semantic Normalized Cumulative Gain) Metric evaluates the quality of predicted sentences
55
+ (abstractive/extractive) in relation to reference sentences and documents using Semantic Normalized Cumulative Gain
56
+ (NCG). It computes gain values and NCG scores based on cosine similarity between sentence embeddings, leveraging a
57
+ Sentence-BERT encoder. This metric is designed to assess the relevance and ranking of predicted sentences, making it
58
+ useful for tasks such as summarization and information retrieval.
59
  """
60
 
 
 
61
  _KWARGS_DESCRIPTION = """
62
+ Sem-nCG (Semantic Normalized Cumulative Gain) compares system-generated summaries (predictions) with ground-truth
+ reference summaries (references) and the corresponding input documents (documents).
64
+ It computes gain values and NCG scores based on sentence embeddings.
65
+
66
  Args:
67
+ predictions (DOCUMENT_TYPE): The predicted sentences.
68
+ `tokenize_sentences`=True -> predictions: List[str]
69
+ `tokenize_sentences`=False -> predictions: List[List[str]]
70
+ references (DOCUMENT_TYPE): The reference sentences.
71
+ `tokenize_sentences`=True -> references: List[str]
72
+ `tokenize_sentences`=False -> references: List[List[str]]
73
+ documents (DOCUMENT_TYPE): Input documents.
74
+ `tokenize_sentences`=True -> documents: List[str]
75
+ `tokenize_sentences`=False -> documents: List[List[str]]
76
+ k (int): The rank threshold used for evaluating gains (typically top-k sentences). Default is 3.
77
+ gpu (Union[bool, str, int, List[Union[str, int]]]): Whether to use GPU or CPU for computation.
78
+ bool -
79
+ False - CPU (Default)
80
+ True - GPU (device 0) if gpu is available else CPU
81
+ int -
82
+ n - GPU, device index n
83
+ str -
84
+ 'cuda', 'gpu', 'cpu'
85
+ List[Union[str, int]] - Multiple GPUs/cpus i.e. use multiple processes when computing embeddings
86
+ batch_size (int): Batch size for encoding. Default is 32.
87
+ verbose (bool): Flag to indicate verbose output. Default is False.
88
+ tokenize_sentences (bool): Flag to indicate whether to tokenize the sentences in the input documents. Default: True.
89
+ pre_compute_embeddings (bool): Flag to indicate whether to pre-compute embeddings for all sentences. This speeds up
90
+ computation but requires more memory. Default is False.
91
+ debug (bool): Flag to return detailed debug information including ranked gains. Default is False.
92
+
93
  Returns:
94
+ Union[Tuple[float, List[float]], Tuple[float, List[RankedGains]]]:
95
+ If `debug` is False, returns a tuple containing the mean SemnCG score and a list of SemnCG scores for each document.
96
+ If `debug` is True, returns a tuple containing the mean SemnCG score and a list of `RankedGains` objects with
97
+ detailed gain information for each document.
98
+
99
+ Examples of input formats:
100
+
101
+ Case 1: tokenize_sentences = True
102
+ predictions: List[str] - List of predictions where each prediction is a document.
103
+ references: List[str] - List of references where each reference is a document.
104
+ documents: List[str] - List of input documents where each document is a document.
105
+ Example:
106
+ predictions = ["This is a prediction sentence 1. This is a prediction sentence 2."]
107
+ references = ["This is a reference sentence 1. This is a reference sentence 2."]
108
+ documents = ["This is a document sentence 1. This is a document sentence 2."]
109
+
110
+ Case 2: tokenize_sentences = False
111
+ predictions: List[List[str]] - List of predictions where each prediction is a list of sentences.
112
+ references: List[List[str]] - List of references where each reference is a list of sentences.
113
+ documents: List[List[str]] - List of input documents where each document is a list of sentences.
114
+ Example:
115
+ predictions = [["This is a prediction sentence 1.", "This is a prediction sentence 2."]]
116
+ references = [["This is a reference sentence 1.", "This is a reference sentence 2."]]
117
+ documents = [["This is a document sentence 1.", "This is a document sentence 2."]]
118
+
119
  Examples:
 
 
120
 
121
+ >>> import evaluate
122
+ >>> predictions = ["This is a prediction sentence 1. This is a prediction sentence 2."]
123
+ >>> references = ["This is a reference sentence 1. This is a reference sentence 2."]
124
+ >>> documents = ["This is a document sentence 1. This is a document sentence 2."]
125
+ >>> metric = evaluate.load("nbansal/semncg", model_name="all-MiniLM-L6-v2")
126
+ >>> mean_score, scores = metric.compute(predictions=predictions, references=references, documents=documents)
127
+ >>> print(f"Mean SemnCG: {mean_score}")
128
  """
129
 
130
+
131
+
132
+
133
+ @dataclass
134
+ class RankedGains:
135
+ """
136
+ Dataclass to store ranked gains and associated metadata.
137
+
138
+ Attributes:
139
+ gt_gains (List[Tuple[str, float]]): List of tuples representing ground truth (ideal) gains,
140
+ where each tuple contains a document sentence and its corresponding gain value.
141
+ pred_gains (List[Tuple[str, float]]): List of tuples representing predicted gains by the model,
142
+ where each tuple contains a document identifier and its corresponding gain value.
143
+ k (int): The rank threshold used for evaluating gains (typically top-k documents).
144
+ ncg (float): Normalized Cumulative Gain (NCG) score calculated based on the predicted gains
145
+ compared to the ground truth gains.
146
+
147
+ Notes:
148
+ - `gt_gains` and `pred_gains` are typically sorted in descending order
149
+ - `k` specifies the top-k threshold used for evaluating the gains.
150
+ - `ncg` provides a normalized measure of the model's performance.
151
+ """
152
+ gt_gains: List[Tuple[str, float]]
153
+ pred_gains: List[Tuple[str, float]]
154
+ k: int
155
+ ncg: float
156
+
157
+
158
+ def compute_cosine_similarity(doc_embeds: NDArray, ref_embeds: NDArray) -> List[float]:
159
+ """
160
+ Compute cosine similarity scores between each document embedding and reference embeddings.
161
+
162
+ Args:
163
+ doc_embeds (NDArray): 2D array of shape (#Docs, Embedding_dim) containing document embeddings.
164
+ ref_embeds (NDArray): 2D array of shape (#Refs, Embedding_dim) containing reference embeddings.
165
+
166
+ Returns:
167
+ List[float]: A list of mean cosine similarity scores between each document and reference embeddings.
168
+ The length of the list is equal to the number of documents (#Docs).
169
+
170
+ Notes:
171
+ - Uses cosine_similarity function from sklearn.metrics.pairwise to compute pairwise cosine similarities.
172
+ - Returns the mean cosine similarity scores across reference embeddings for each document embedding.
173
+ """
174
+ # Compute cosine similarity between predicted and reference embeddings
175
+ cosine_scores = cosine_similarity(doc_embeds, ref_embeds) # [#Docs, #Refs]
176
+ return np.mean(cosine_scores, axis=1).tolist()
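+ # Example (matching the unit tests): doc_embeds = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] and
+ # ref_embeds = [[0.2, 0.3, 0.4], [0.5, 0.6, 0.7]] yield per-document mean cosine
+ # similarities of approximately [0.980, 0.997].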
177
+
178
+
179
+ def compute_gain(sim_scores: List[float]) -> List[Tuple[int, float]]:
180
+ """
181
+ Compute gain values for ranked similarity scores.
182
+
183
+ Args:
184
+ sim_scores (List[float]): List of similarity scores for documents (`compute_cosine_similarity(doc_embeds, ref_embeds)`)
185
+
186
+ Returns:
187
+ List[Tuple[int, float]]: A list of tuples where each tuple contains a document index and its corresponding gain
188
+ value. The list is sorted by descending order of gain values.
189
+
190
+ Notes:
191
+ - Computes gain values based on the rank order of similarity scores, where higher scores indicate higher gains.
192
+ - Uses the formula: gain = rank_position / sum of ranks, where rank_position starts from 1 for the highest score
193
+ - Returns a list sorted by descending gain values.
194
+ """
195
+ count = len(sim_scores)
196
+ sim_scores = np.array(sim_scores).argsort()[::-1] # Reverse Sorted Order of doc sentence indices
197
+ denominator = count * (count + 1) / 2 # (n * (n+1))/2
198
+ return [(s_idx, val / denominator) for s_idx, val in zip(sim_scores, range(count, 0, -1))]
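+ # Example (matching the unit tests): for sim_scores = [0.8, 0.6, 0.7], the document indices sorted by
+ # descending similarity are [0, 2, 1], the assigned ranks are [3, 2, 1], and the denominator is
+ # 3 * 4 / 2 = 6, so the returned gains are [(0, 3/6), (2, 2/6), (1, 1/6)].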
199
+
200
+
201
+ def score_ncg(model_relevance: List[float], gt_relevance: List[float]) -> float:
202
+ """
203
+ Calculate the Normalized Cumulative Gain (NCG) score based on model relevance and ground truth relevance.
204
+
205
+ Args:
206
+ model_relevance (List[float]): List of gain values representing the relevance scores predicted by the model.
207
+ gt_relevance (List[float]): List of gain values representing the ground truth (ideal) relevance scores.
208
+
209
+ Returns:
210
+ float: Normalized Cumulative Gain (NCG) score, which measures the effectiveness of the model's relevance
211
+ predictions compared to the ideal relevance scores. The score ranges from 0 to 1, where higher values
212
+ indicate better performance.
213
+
214
+ Notes:
215
+ - Calculates Cumulative Gain (CG) for both model and ground truth relevance lists.
216
+ - Normalizes CG scores by dividing model CG by ground truth CG to get the NCG score.
217
+ - Returns 0 if the ground truth CG (icg) is 0 to avoid division by zero.
218
+ """
219
+
220
+ # CG score
221
+ cg = sum(model_relevance)
222
+
223
+ # ICG score
224
+ icg = sum(gt_relevance)
225
+
226
+ # Normalized CG score
227
+ return cg / icg if icg != 0 else 0
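+ # Example (matching the unit tests): model_relevance = [0.8, 0.7, 0.6] and gt_relevance = [1.0, 0.9, 0.8]
+ # give cg = 2.1 and icg = 2.7, so the returned NCG score is 2.1 / 2.7 ≈ 0.778.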
228
+
229
+
230
+ def compute_ncg(pred_gains: List[Tuple[int, float]], gt_gains: List[Tuple[int, float]], k: int) -> float:
231
+ """
232
+ Compute the Normalized Cumulative Gain (NCG) score based on predicted and ground truth gains up to rank k.
233
+
234
+ Args:
235
+ pred_gains (List[Tuple[int, float]]): List of tuples representing predicted gains by the model,
236
+ where each tuple contains a document position (or index) and its corresponding gain value.
237
+ (Sorted in Descending Order)
238
+ gt_gains (List[Tuple[int, float]]): List of tuples representing ground truth gains (ideal gains),
239
+ where each tuple contains a document position (or index) and its corresponding gain value.
240
+ (Sorted in Descending Order)
241
+ k (int): The rank threshold used for evaluating gains (typically top-k documents).
242
+
243
+ Returns:
244
+ float: Normalized Cumulative Gain (NCG) score based on the predicted gains compared to the ground truth gains.
245
+
246
+ Notes:
247
+ - Both `pred_gains` and `gt_gains` should be sorted lists (in descending order) where higher gain values indicate
248
+ higher relevance.
249
+ - The function calculates NCG up to rank `k`, considering only the top-k documents.
250
+ - Uses the `score_ncg` function to compute the NCG score based on the model's predicted gains and the ground
251
+ truth.
252
+ """
253
+ gt_dict = dict(gt_gains)
254
+ gt_rel = [v for _, v in gt_gains[:k]]
255
+ model_rel = [gt_dict[position] for position, _ in pred_gains[:k]]
256
+ return score_ncg(model_rel, gt_rel)
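+ # Example (matching the unit tests): with gt_gains = [(0, 1.0), (1, 0.9), (2, 0.8)],
+ # pred_gains = [(0, 0.8), (2, 0.7), (1, 0.6)] and k = 3, the model's top-3 sentences are the same set
+ # as the ideal top-3, so both sums equal 2.7 and the NCG score is 1.0.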
257
+
258
+
259
+ def _validate_input_format(
260
+ tokenize_sentences: bool,
261
+ predictions: DOCUMENT_TYPE,
262
+ references: DOCUMENT_TYPE,
263
+ documents: DOCUMENT_TYPE
264
+ ):
265
+ """
266
+ Validate the format of predictions, references, and documents based on specified criteria.
267
+
268
+ Args:
269
+ tokenize_sentences (bool): Flag indicating whether sentences should be tokenized.
270
+ predictions (DOCUMENT_TYPE): Predictions to validate.
271
+ references (DOCUMENT_TYPE): References to validate.
272
+ documents (DOCUMENT_TYPE): Documents to validate.
273
+
274
+ Raises:
275
+ ValueError: If the format of predictions, references, or documents does not meet the specified criteria.
276
+
277
+ Validation Criteria:
278
+ The function validates predictions, references, and documents based on the following conditions:
279
+ 1. If `tokenize_sentences` is True:
280
+ - Predictions, references, and documents must all be lists of strings (`is_list_of_strings_at_depth(obj, 1)`).
281
+
282
+ 2. If `tokenize_sentences` is False:
283
+ - Predictions, references, and documents must all be lists of lists of strings
284
+ (`is_list_of_strings_at_depth(obj, 2)`).
285
+
286
+ The function checks these conditions and raises a ValueError if any condition is not met,
287
+ indicating that predictions, references, or documents are not in the valid input format.
288
+
289
+ Notes:
290
+ - `DOCUMENT_TYPE`: Union[List[str], List[List[str]]]
291
+ - Uses helper function `is_list_of_strings_at_depth` to validate the format of lists of strings.
292
+
293
+ Example:
294
+ >>> tokenize_sentences = True
295
+ >>> predictions = ["This is prediction 1.", "This is prediction 2."]
296
+ >>> references = ["Reference for prediction 1.", "Reference for prediction 2."]
297
+ >>> documents = ["Document 1 content.", "Document 2 content."]
298
+ >>> _validate_input_format(tokenize_sentences, predictions, references, documents)
299
+
300
+ Example:
301
+ >>> tokenize_sentences = False
302
+ >>> predictions = [["Sentence 1 in prediction 1.", "Sentence 2 in prediction 1."],
303
+ >>> ["Sentence 1 in prediction 2.", "Sentence 2 in prediction 2."]]
304
+ >>> references = [["Sentences in reference 1."], ["Sentences in reference 2."]]
305
+ >>> documents = [["Sentence 1 in document 1.", "Sentence 2 in document 1."],
306
+ >>> ["Sentence 1 in document 2.", "Sentence 2 in document 2."]]
307
+ >>> _validate_input_format(tokenize_sentences, predictions, references, documents)
308
+ """
309
+ if not (len(predictions) == len(references) == len(documents)):
310
+ raise ValueError("Predictions, References and Documents must have the same length.")
311
+
312
+ if len(predictions) == 0:
313
+ raise ValueError("Can't have empty inputs")
314
+
315
+ def is_list_of_strings_at_depth(lst_obj, depth: int):
316
+ return is_nested_list_of_type(lst_obj, element_type=str, depth=depth)
317
+
318
+ if tokenize_sentences:
319
+ condition = (
320
+ is_list_of_strings_at_depth(predictions, 1) and
321
+ is_list_of_strings_at_depth(references, 1) and
322
+ is_list_of_strings_at_depth(documents, 1)
323
+ )
324
+ else:
325
+ condition = (
326
+ is_list_of_strings_at_depth(predictions, 2) and
327
+ is_list_of_strings_at_depth(references, 2) and
328
+ is_list_of_strings_at_depth(documents, 2)
329
+ )
330
+
331
+ if not condition:
332
+ raise ValueError("Predictions, References and Documents are not valid input format. Refer to documentation.")
333
 
334
 
335
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
336
  class SemnCG(evaluate.Metric):
337
+ """
338
+ SemnCG (Semantic Normalized Cumulative Gain) Metric.
339
+
340
+ This metric evaluates the quality of predicted sentences in relation to reference sentences and documents
341
+ using Semantic Normalized Cumulative Gain (NCG). It computes the gain values and NCG scores based on
342
+ cosine similarity between sentence embeddings, leveraging a Sentence-BERT encoder.
343
+ """
344
+
345
+ def __init__(self, model_name: str = "all-MiniLM-L6-v2", **kwargs):
346
+ self.sbert_encoder = get_sbert_encoder(model_name)
347
+ super().__init__(**kwargs)
348
 
349
  def _info(self):
350
  # TODO: Specifies the evaluate.EvaluationModuleInfo object
 
355
  citation=_CITATION,
356
  inputs_description=_KWARGS_DESCRIPTION,
357
  # This defines the format of each prediction and reference
358
+ features=[
359
+ # Tokenize_Sentences = True
360
+ datasets.Features(
361
+ {
362
+ "predictions": datasets.Value("string"),
363
+ "references": datasets.Value("string"),
364
+ "documents": datasets.Value("string"),
365
+ }
366
+ ),
367
+ # Tokenize_Sentences = False
368
+ datasets.Features(
369
+ {
370
+ "predictions": datasets.Sequence(datasets.Value("string", id="sequence"), id="predictions"),
371
+ "references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"),
372
+ "documents": datasets.Sequence(datasets.Value("string", id="sequence"), id="documents"),
373
+ }
374
+ ),
375
+ ],
376
+ # # Homepage of the module for documentation
377
+ # homepage="http://module.homepage",
378
+ # # Additional links to the codebase or references
379
+ # codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
380
+ reference_urls=["https://aclanthology.org/2022.findings-acl.122/"]
381
  )
382
 
383
  def _download_and_prepare(self, dl_manager):
384
  """Optional: download external resources useful to compute the scores"""
385
+ nltk.download("punkt", quiet=True)
386
+
387
+ def _compute(
388
+ self,
389
+ predictions: DOCUMENT_TYPE,
390
+ references: DOCUMENT_TYPE,
391
+ documents: DOCUMENT_TYPE,
392
+ k: int = 3,
393
+ gpu: DEVICE_TYPE = False,
394
+ verbose: bool = False,
395
+ batch_size: int = 32,
396
+ tokenize_sentences: bool = True,
397
+ pre_compute_embeddings: bool = False,
398
+ debug: bool = False,
399
+ ) -> Union[Tuple[float, List[float]], Tuple[float, List[RankedGains]]]:
400
+ """
401
+ Compute the Semantic Normalized Cumulative Gain (SemnCG) score.
402
+
403
+ Args:
404
+ predictions (DOCUMENT_TYPE): The predicted sentences.
405
+ `tokenize_sentences`=True -> predictions: List[str]
406
+ `tokenize_sentences`=False -> predictions: List[List[str]]
407
+ references (DOCUMENT_TYPE): The reference sentences.
408
+ `tokenize_sentences`=True -> references: List[str]
409
+ `tokenize_sentences`=False -> references: List[List[str]]
410
+ documents (DOCUMENT_TYPE): Input documents.
411
+ `tokenize_sentences`=True -> documents: List[str]
412
+ `tokenize_sentences`=False -> documents: List[List[str]]
413
+ k (int, optional): The rank threshold used for evaluating gains (typically top-k sentences). Default is 3.
414
+ gpu (DEVICE_TYPE, optional): Whether to use GPU for computation. Default is False.
415
+ verbose (bool, optional): Whether to print verbose logs. Default is False.
416
+ batch_size (int, optional): The batch size for encoding sentences. Default is 32.
417
+ tokenize_sentences (bool, optional): Whether to tokenize sentences. If True, sentences are tokenized before
418
+ processing. Default is True.
419
+ pre_compute_embeddings (bool, optional): Whether to pre-compute embeddings for all sentences. This speeds up
420
+ computation but requires more memory. Default is False.
421
+ debug (bool, optional): Whether to return detailed debug information including ranked gains. Default=False.
422
+
423
+ Returns:
424
+ Union[Tuple[float, List[float]], Tuple[float, List[RankedGains]]]:
425
+ If `debug` is False, returns a tuple containing the mean SemnCG score and a list of SemnCG scores for each document.
426
+ If `debug` is True, returns a tuple containing the mean SemnCG score and a list of `RankedGains` objects with detailed gain information for each document.
427
+
428
+ Raises:
429
+ ValueError: If the format of predictions, references, or documents does not meet the specified criteria.
430
+
431
+ Notes:
432
+ - Validates the format of predictions, references, and documents based on `tokenize_sentences`.
433
+ - Computes embeddings using a Sentence-BERT encoder.
434
+ - Computes cosine similarity between document, reference, and prediction embeddings.
435
+ - Calculates gain values and Normalized Cumulative Gain (NCG) scores.
436
+ - Optionally returns detailed debug information for each document if `debug` is True.
437
+ """
438
+
439
+ # Validate inputs corresponding to flags
440
+ _validate_input_format(tokenize_sentences, predictions, references, documents)
441
+
442
+ # Get GPU
443
+ device = get_gpu(gpu)
444
+ if verbose:
445
+ print(f"Using devices: {device}")
446
+
447
+ # Get model
448
+ encoder = get_encoder(self.sbert_encoder, device=device, batch_size=batch_size, verbose=verbose)
449
+
450
+ if pre_compute_embeddings: # fast but takes more memory
451
+ predictions = [tokenize_and_prep_document(pred, tokenize_sentences) for pred in predictions]
452
+ references = [tokenize_and_prep_document(ref, tokenize_sentences) for ref in references]
453
+ documents = [tokenize_and_prep_document(doc, tokenize_sentences) for doc in documents]
454
+
455
+ # This is only done for debug case
456
+ sent_tokenized_documents = documents
457
+
458
+ # Compute All Embeddings
459
+ all_sentences = flatten_list(documents) + flatten_list(references) + flatten_list(predictions)
460
+ embeddings = encoder.encode(all_sentences)
461
+
462
+ prediction_sentences_count = [len(pred) for pred in predictions]
463
+ reference_sentences_count = [len(ref) for ref in references]
464
+ document_sentences_count = [len(doc) for doc in documents]
465
+
466
+ # Get embeddings corresponding to documents, references and predictions (IN ORDER)
467
+ doc_embeddings = slice_embeddings(embeddings, document_sentences_count)
468
+ ref_embeddings = slice_embeddings(embeddings[sum(document_sentences_count):], reference_sentences_count)
469
+ pred_embeddings = slice_embeddings(
470
+ embeddings[sum(document_sentences_count+reference_sentences_count):], prediction_sentences_count
471
+ )
472
+
473
+ iterable_obj = zip(pred_embeddings, ref_embeddings, doc_embeddings)
474
+
475
+ else:
476
+ iterable_obj = zip(predictions, references, documents)
477
+
478
+ out = []
479
+ for idx, (pred, ref, doc) in enumerate(tqdm(iterable_obj)):
480
+
481
+ if not pre_compute_embeddings: # Compute embeddings
482
+ ref_sentences = tokenize_and_prep_document(ref, tokenize_sentences)
483
+ pred_sentences = tokenize_and_prep_document(pred, tokenize_sentences)
484
+ doc_sentences = tokenize_and_prep_document(doc, tokenize_sentences)
485
+
486
+ # Compute Embeddings
487
+ doc_sentence_count = len(doc_sentences)
488
+ ref_sentence_count = len(ref_sentences)
489
+ all_sentences = doc_sentences + ref_sentences + pred_sentences
490
+ embeddings = encoder.encode(all_sentences)
491
+ doc_embeddings = embeddings[:doc_sentence_count]
492
+ ref_embeddings = embeddings[doc_sentence_count:doc_sentence_count + ref_sentence_count]
493
+ pred_embeddings = embeddings[doc_sentence_count + ref_sentence_count:]
494
+ else: # we already have embeddings
495
+ doc_embeddings = doc
496
+ ref_embeddings = ref
497
+ pred_embeddings = pred
498
+
499
+ doc_sentences = sent_tokenized_documents[idx]
500
+
501
+ # Compute Pair-Wise Cosine Similarity
502
+ ref_sim_scores = compute_cosine_similarity(doc_embeddings, ref_embeddings)
503
+ pred_sim_scores = compute_cosine_similarity(doc_embeddings, pred_embeddings)
504
+
505
+ # Compute Gains
506
+ ground_truth_gain = compute_gain(ref_sim_scores)
507
+
508
+ # this is used to compute top-predicted sentence indices
509
+ pred_gain = compute_gain(pred_sim_scores)
510
+ real_k = min(len(pred_gain), k)
511
+
512
+ # Compute NCG Scores
513
+ ncg_score = compute_ncg(pred_gain, ground_truth_gain, real_k)
514
+
515
+ if debug:
516
+ ground_truth_gain = [(doc_sentences[sent_idx], gain_val) for sent_idx, gain_val in ground_truth_gain]
517
+ pred_gain = [(doc_sentences[sent_idx], gain_val) for sent_idx, gain_val in pred_gain]
518
+ out.append(RankedGains(ground_truth_gain, pred_gain, k=real_k, ncg=ncg_score))
519
+ else:
520
+ out.append(ncg_score)
521
+
522
+ if debug:
523
+ return statistics.mean([ele.ncg for ele in out]), out
524
+
525
+ return statistics.mean(out), out
tests.py CHANGED
@@ -1,17 +1,418 @@
1
- test_cases = [
2
- {
3
- "predictions": [0, 0],
4
- "references": [1, 1],
5
- "result": {"metric_score": 0}
6
- },
7
- {
8
- "predictions": [1, 1],
9
- "references": [1, 1],
10
- "result": {"metric_score": 1}
11
- },
12
- {
13
- "predictions": [1, 0],
14
- "references": [1, 1],
15
- "result": {"metric_score": 0.5}
16
- }
17
- ]
1
+ import statistics
2
+ import unittest
3
+ from unittest.mock import patch, MagicMock
4
+
5
+ import numpy as np
6
+ import torch
7
+ from numpy.testing import assert_almost_equal
8
+ from sentence_transformers import SentenceTransformer
9
+ from sklearn.metrics.pairwise import cosine_similarity
10
+
11
+ from .encoder_models import SBertEncoder, get_encoder, get_sbert_encoder
12
+ from .semncg import RankedGains, compute_cosine_similarity, compute_gain, score_ncg, compute_ncg, _validate_input_format, SemnCG
13
+ from .utils import get_gpu, slice_embeddings, is_nested_list_of_type, flatten_list, prep_sentences, tokenize_and_prep_document
14
+
15
+
16
+ class TestUtils(unittest.TestCase):
17
+ def test_get_gpu(self):
18
+ gpu_count = torch.cuda.device_count()
19
+ gpu_available = torch.cuda.is_available()
20
+
21
+ # Test single boolean input
22
+ self.assertEqual(get_gpu(True), 0 if gpu_available else "cpu")
23
+ self.assertEqual(get_gpu(False), "cpu")
24
+
25
+ # Test single string input
26
+ self.assertEqual(get_gpu("cpu"), "cpu")
27
+ self.assertEqual(get_gpu("gpu"), 0 if gpu_available else "cpu")
28
+ self.assertEqual(get_gpu("cuda"), 0 if gpu_available else "cpu")
29
+
30
+ # Test single integer input
31
+ self.assertEqual(get_gpu(0), 0 if gpu_available else "cpu")
32
+ self.assertEqual(get_gpu(1), 1 if gpu_available else "cpu")
33
+
34
+ # Test list input with unique elements
35
+ self.assertEqual(get_gpu([True, "cpu", 0]), [0, "cpu"] if gpu_available else ["cpu", "cpu", "cpu"])
36
+
37
+ # Test list input with duplicate elements
38
+ self.assertEqual(get_gpu([0, 0, "gpu"]), 0 if gpu_available else ["cpu", "cpu", "cpu"])
39
+
40
+ # Test list input with duplicate elements of different types
41
+ self.assertEqual(get_gpu([True, 0, "gpu"]), 0 if gpu_available else ["cpu", "cpu", "cpu"])
42
+
43
+ # Test list input but only one element
44
+ self.assertEqual(get_gpu([True]), 0 if gpu_available else "cpu")
45
+
46
+ # Test list input with all integers
47
+ self.assertEqual(get_gpu(list(range(gpu_count))),
48
+ list(range(gpu_count)) if gpu_available else gpu_count * ["cpu"])
49
+
50
+ with self.assertRaises(ValueError):
51
+ get_gpu("invalid")
52
+
53
+ with self.assertRaises(ValueError):
54
+ get_gpu(torch.cuda.device_count())
55
+
56
+ def test_prep_sentences(self):
57
+ # Test normal case
58
+ self.assertEqual(prep_sentences(["Hello, world!", " This is a test. ", "!!!"]),
59
+ ['Hello, world!', 'This is a test.'])
60
+
61
+ # Test case with only punctuations
62
+ with self.assertRaises(ValueError):
63
+ prep_sentences(["!!!", "..."])
64
+
65
+ # Test case with empty list
66
+ with self.assertRaises(ValueError):
67
+ prep_sentences([])
68
+
69
+ def test_tokenize_and_prep_document(self):
70
+ # Test tokenize=True with string input
71
+ self.assertEqual(tokenize_and_prep_document("Hello, world! This is a test.", True),
72
+ ['Hello, world!', 'This is a test.'])
73
+
74
+ # Test tokenize=False with list of strings input
75
+ self.assertEqual(tokenize_and_prep_document(["Hello, world!", "This is a test."], False),
76
+ ['Hello, world!', 'This is a test.'])
77
+
78
+ # Test tokenize=True with empty document
79
+ with self.assertRaises(ValueError):
80
+ tokenize_and_prep_document("!!! ...", True)
81
+
82
+ def test_slice_embeddings(self):
83
+ # Case 1
84
+ embeddings = np.random.rand(10, 5)
85
+ num_sentences = [3, 2, 5]
86
+ expected_output = [embeddings[:3], embeddings[3:5], embeddings[5:]]
87
+ self.assertTrue(
88
+ all(np.array_equal(a, b) for a, b in zip(slice_embeddings(embeddings, num_sentences),
89
+ expected_output))
90
+ )
91
+
92
+ # Case 2
93
+ num_sentences_nested = [[2, 1], [3, 4]]
94
+ expected_output_nested = [[embeddings[:2], embeddings[2:3]], [embeddings[3:6], embeddings[6:]]]
95
+ self.assertTrue(
96
+ slice_embeddings(embeddings, num_sentences_nested), expected_output_nested
97
+ )
98
+
99
+ # Case 3
100
+ document_sentences_count = [10, 8, 7]
101
+ reference_sentences_count = [5, 3, 2]
102
+ pred_sentences_count = [2, 2, 1]
103
+ all_embeddings = np.random.rand(
104
+ sum(document_sentences_count + reference_sentences_count + pred_sentences_count), 5,
105
+ )
106
+
107
+ embeddings = all_embeddings
108
+ expected_doc_embeddings = [embeddings[:10], embeddings[10:18], embeddings[18:25]]
109
+
110
+ embeddings = all_embeddings[25:]
111
+ expected_ref_embeddings = [embeddings[:5], embeddings[5:8], embeddings[8:10]]
112
+
113
+ embeddings = all_embeddings[35:]
114
+ expected_pred_embeddings = [embeddings[:2], embeddings[2:4], embeddings[4:5]]
115
+
116
+ doc_embeddings = slice_embeddings(all_embeddings, document_sentences_count)
117
+ ref_embeddings = slice_embeddings(all_embeddings[sum(document_sentences_count):], reference_sentences_count)
118
+ pred_embeddings = slice_embeddings(
119
+ all_embeddings[sum(document_sentences_count+reference_sentences_count):], pred_sentences_count
120
+ )
121
+
122
+ self.assertTrue(doc_embeddings, expected_doc_embeddings)
123
+ self.assertTrue(ref_embeddings, expected_ref_embeddings)
124
+ self.assertTrue(pred_embeddings, expected_pred_embeddings)
125
+
126
+ with self.assertRaises(TypeError):
127
+ slice_embeddings(embeddings, "invalid")
128
+
129
+ def test_is_nested_list_of_type(self):
130
+ # Test case: Depth 0, single element matching element_type
131
+ self.assertTrue(is_nested_list_of_type("test", str, 0))
132
+
133
+ # Test case: Depth 0, single element not matching element_type
134
+ self.assertFalse(is_nested_list_of_type("test", int, 0))
135
+
136
+ # Test case: Depth 1, list of elements matching element_type
137
+ self.assertTrue(is_nested_list_of_type(["apple", "banana"], str, 1))
138
+
139
+ # Test case: Depth 1, list of elements not matching element_type
140
+ self.assertFalse(is_nested_list_of_type([1, 2, 3], str, 1))
141
+
142
+ # Test case: Depth 0 (Wrong), list of elements matching element_type
143
+ self.assertFalse(is_nested_list_of_type([1, 2, 3], str, 0))
144
+
145
+ # Depth 2
146
+ self.assertTrue(is_nested_list_of_type([[1, 2], [3, 4]], int, 2))
147
+ self.assertTrue(is_nested_list_of_type([['1', '2'], ['3', '4']], str, 2))
148
+ self.assertFalse(is_nested_list_of_type([[1, 2], ["a", "b"]], int, 2))
149
+
150
+ # Depth 3
151
+ self.assertFalse(is_nested_list_of_type([[[1], [2]], [[3], [4]]], list, 3))
152
+ self.assertTrue(is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3))
153
+
154
+ with self.assertRaises(ValueError):
155
+ is_nested_list_of_type([1, 2], int, -1)
156
+
157
+ def test_flatten_list(self):
158
+ self.assertEqual(flatten_list([1, [2, 3], [[4], 5]]), [1, 2, 3, 4, 5])
159
+ self.assertEqual(flatten_list([]), [])
160
+ self.assertEqual(flatten_list([1, 2, 3]), [1, 2, 3])
161
+ self.assertEqual(flatten_list([[[[1]]]]), [1])
162
+
163
+
164
+ class TestSBertEncoder(unittest.TestCase):
165
+
166
+ def setUp(self) -> None:
167
+ # Set up a test SentenceTransformer model
168
+ self.model_name = "paraphrase-distilroberta-base-v1"
169
+ self.sbert_model = get_sbert_encoder(self.model_name)
170
+ self.device = "cpu" # For testing on CPU
171
+ self.batch_size = 32
172
+ self.verbose = False
173
+ self.encoder = SBertEncoder(self.sbert_model, self.device, self.batch_size, self.verbose)
174
+
175
+ def test_encode_single_sentence(self):
176
+ sentence = "Hello, world!"
177
+ embeddings = self.encoder.encode([sentence])
178
+ self.assertEqual(embeddings.shape, (1, 768)) # Adjust shape based on your model's embedding dimension
179
+
180
+ def test_encode_multiple_sentences(self):
181
+ sentences = ["Hello, world!", "This is a test."]
182
+ embeddings = self.encoder.encode(sentences)
183
+ self.assertEqual(embeddings.shape, (2, 768)) # Adjust shape based on your model's embedding dimension
184
+
185
+ def test_get_sbert_encoder(self):
186
+ model_name = "paraphrase-distilroberta-base-v1"
187
+ sbert_model = get_sbert_encoder(model_name)
188
+ self.assertIsInstance(sbert_model, SentenceTransformer)
189
+
190
+ def test_encode_with_gpu(self):
191
+ if torch.cuda.is_available():
192
+ device = "cuda"
193
+ encoder = get_encoder(self.sbert_model, device, self.batch_size, self.verbose)
194
+ sentences = ["Hello, world!", "This is a test."]
195
+ embeddings = encoder.encode(sentences)
196
+ self.assertEqual(embeddings.shape, (2, 768)) # Adjust shape based on your model's embedding dimension
197
+ else:
198
+ self.skipTest("CUDA not available, skipping GPU test.")
199
+
200
+ def test_encode_multi_device(self):
201
+ if torch.cuda.device_count() < 2:
202
+ self.skipTest("Multi-GPU test requires at least 2 GPUs.")
203
+ else:
204
+ devices = ["cuda:0", "cuda:1"]
205
+ encoder = get_encoder(self.sbert_model, devices, self.batch_size, self.verbose)
206
+ sentences = ["This is a test sentence.", "Here is another sentence.", "This is a test sentence."]
207
+ embeddings = encoder.encode(sentences)
208
+ self.assertIsInstance(embeddings, np.ndarray)
209
+ self.assertEqual(embeddings.shape[0], 3)
210
+ self.assertEqual(embeddings.shape[1], self.encoder.model.get_sentence_embedding_dimension())
211
+
212
+
213
+ class TestGetEncoder(unittest.TestCase):
214
+ def setUp(self):
215
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
216
+ self.batch_size = 8
217
+ self.verbose = False
218
+
219
+ def _base_test(self, model_name):
220
+ sbert_model = get_sbert_encoder(model_name)
221
+ encoder = get_encoder(sbert_model, self.device, self.batch_size, self.verbose)
222
+
223
+ # Assert
224
+ self.assertIsInstance(encoder, SBertEncoder)
225
+ self.assertEqual(encoder.device, self.device)
226
+ self.assertEqual(encoder.batch_size, self.batch_size)
227
+ self.assertEqual(encoder.verbose, self.verbose)
228
+
229
+ def test_get_sbert_encoder(self):
230
+ model_name = "stsb-roberta-large"
231
+ self._base_test(model_name)
232
+
233
+ def test_sbert_model(self):
234
+ model_name = "all-mpnet-base-v2"
235
+ self._base_test(model_name)
236
+
237
+ def test_huggingface_model(self):
238
+ """Test Huggingface models which work with SBert library"""
239
+ model_name = "roberta-base"
240
+ self._base_test(model_name)
241
+
242
+ def test_get_encoder_environment_error(self):
243
+ model_name = "abc" # Wrong model_name
244
+ with self.assertRaises(EnvironmentError):
245
+ get_sbert_encoder(model_name)
246
+
247
+ def test_get_encoder_other_exception(self):
248
+ model_name = "apple/OpenELM-270M" # This model is not supported by SentenceTransformer lib
249
+ with self.assertRaises(RuntimeError):
250
+ get_sbert_encoder(model_name)
251
+
252
+
253
+ class TestRankedGainsDataclass(unittest.TestCase):
254
+ def test_ranked_gains_dataclass(self):
255
+ # Test initialization and attribute access
256
+ gt_gains = [("doc1", 0.8), ("doc2", 0.6)]
257
+ pred_gains = [("doc2", 0.7), ("doc1", 0.5)]
258
+ k = 2
259
+ ncg = 0.75
260
+ ranked_gains = RankedGains(gt_gains, pred_gains, k, ncg)
261
+
262
+ self.assertEqual(ranked_gains.gt_gains, gt_gains)
263
+ self.assertEqual(ranked_gains.pred_gains, pred_gains)
264
+ self.assertEqual(ranked_gains.k, k)
265
+ self.assertEqual(ranked_gains.ncg, ncg)
266
+
267
+
268
+ class TestComputeCosineSimilarity(unittest.TestCase):
269
+ def test_compute_cosine_similarity(self):
270
+ doc_embeds = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
271
+ ref_embeds = np.array([[0.2, 0.3, 0.4], [0.5, 0.6, 0.7]])
272
+ # Test compute_cosine_similarity function
273
+ similarity_scores = compute_cosine_similarity(doc_embeds, ref_embeds)
274
+ print(similarity_scores)
275
+
276
+ # Example values, change as per actual function output
277
+ expected_scores = [0.980, 0.997]
278
+
279
+ self.assertAlmostEqual(similarity_scores[0], expected_scores[0], places=3)
280
+ self.assertAlmostEqual(similarity_scores[1], expected_scores[1], places=3)
281
+
282
+
283
+ class TestComputeGain(unittest.TestCase):
284
+ def test_compute_gain(self):
285
+ # Test compute_gain function
286
+ sim_scores = [0.8, 0.6, 0.7]
287
+ gains = compute_gain(sim_scores)
288
+ print(gains)
289
+
290
+ # Expected (sentence index, gain) pairs, ranked by similarity score
291
+ expected_gains = [(0, 0.5), (2, 0.3333333333333333), (1, 0.16666666666666666)]
292
+
293
+ self.assertEqual(gains, expected_gains)
294
+
295
+
296
+ class TestScoreNcg(unittest.TestCase):
297
+ def test_score_ncg(self):
298
+ # Test score_ncg function
299
+ model_relevance = [0.8, 0.7, 0.6]
300
+ gt_relevance = [1.0, 0.9, 0.8]
301
+ ncg_score = score_ncg(model_relevance, gt_relevance)
302
+ expected_ncg = 0.778 # Expected NCG for the relevance lists above
303
+
304
+ self.assertAlmostEqual(ncg_score, expected_ncg, places=3)
305
+
306
+
307
+ class TestComputeNcg(unittest.TestCase):
308
+ def test_compute_ncg(self):
309
+ # Test compute_ncg function
310
+ pred_gains = [(0, 0.8), (2, 0.7), (1, 0.6)]
311
+ gt_gains = [(0, 1.0), (1, 0.9), (2, 0.8)]
312
+ k = 3
313
+ ncg_score = compute_ncg(pred_gains, gt_gains, k)
314
+ expected_ncg = 1.0 # TODO: Confirm this with Dr. Santu
315
+
316
+ self.assertAlmostEqual(ncg_score, expected_ncg, places=6)
317
+
318
+
319
+ class TestValidateInputFormat(unittest.TestCase):
320
+ def test_validate_input_format(self):
321
+ # Test _validate_input_format function
322
+ tokenize_sentences = True
323
+ predictions = ["Prediction 1", "Prediction 2"]
324
+ references = ["Reference 1", "Reference 2"]
325
+ documents = ["Document 1", "Document 2"]
326
+
327
+ # No exception should be raised for valid input
328
+ try:
329
+ _validate_input_format(tokenize_sentences, predictions, references, documents)
330
+ except ValueError as e:
331
+ self.fail(f"_validate_input_format raised ValueError unexpectedly: {str(e)}")
332
+
333
+ # Test invalid input format
334
+ predictions_invalid = [["Sentence 1 in prediction 1.", "Sentence 2 in prediction 1."],
335
+ ["Sentence 1 in prediction 2.", "Sentence 2 in prediction 2."]]
336
+ references_invalid = [["Sentences in reference 1."], ["Sentences in reference 2."]]
337
+ documents_invalid = [["Sentence 1 in document 1.", "Sentence 2 in document 1."],
338
+ ["Sentence 1 in document 2.", "Sentence 2 in document 2."]]
339
+
340
+ with self.assertRaises(ValueError):
341
+ _validate_input_format(tokenize_sentences, predictions_invalid, references, documents)
342
+
343
+ with self.assertRaises(ValueError):
344
+ _validate_input_format(tokenize_sentences, predictions, references_invalid, documents)
345
+
346
+ with self.assertRaises(ValueError):
347
+ _validate_input_format(tokenize_sentences, predictions, references, documents_invalid)
348
+
349
+
350
+ class TestSemnCG(unittest.TestCase):
351
+ def setUp(self):
352
+ self.model_name = "stsb-distilbert-base"
353
+ self.metric = SemnCG(self.model_name)
354
+
355
+ def _basic_assertion(self, result, debug: bool = False):
356
+ self.assertIsInstance(result, tuple)
357
+ self.assertEqual(len(result), 2)
358
+ self.assertIsInstance(result[0], float)
359
+ self.assertTrue(0.0 <= result[0] <= 1.0)
360
+ self.assertIsInstance(result[1], list)
361
+ if debug:
362
+ for ranked_gain in result[1]:
363
+ self.assertTrue(isinstance(ranked_gain, RankedGains))
364
+ self.assertTrue(0.0 <= ranked_gain.ncg <= 1.0)
365
+ else:
366
+ for gain in result[1]:
367
+ self.assertTrue(isinstance(gain, float))
368
+ self.assertTrue(0.0 <= gain <= 1.0)
369
+
370
+ def test_compute_basic(self):
371
+ predictions = ["The cat sat on the mat.", "The quick brown fox jumps over the lazy dog."]
372
+ references = ["A cat was sitting on a mat.", "A quick brown fox jumped over a lazy dog."]
373
+ documents = ["There was a cat on a mat.", "The quick brown fox jumped over the lazy dog."]
374
+
375
+ result = self.metric.compute(predictions=predictions, references=references, documents=documents)
376
+ self._basic_assertion(result)
377
+
378
+ def test_compute_with_tokenization(self):
379
+ predictions = [["The cat sat on the mat."], ["The quick brown fox jumps over the lazy dog."]]
380
+ references = [["A cat was sitting on a mat."], ["A quick brown fox jumped over a lazy dog."]]
381
+ documents = [["There was a cat on a mat."], ["The quick brown fox jumped over the lazy dog."]]
382
+
383
+ result = self.metric.compute(
384
+ predictions=predictions, references=references, documents=documents, tokenize_sentences=False
385
+ )
386
+ self._basic_assertion(result)
387
+
388
+ def test_compute_with_pre_compute_embeddings(self):
389
+ predictions = ["The cat sat on the mat.", "The quick brown fox jumps over the lazy dog."]
390
+ references = ["A cat was sitting on a mat.", "A quick brown fox jumped over a lazy dog."]
391
+ documents = ["There was a cat on a mat.", "The quick brown fox jumped over the lazy dog."]
392
+
393
+ result = self.metric.compute(
394
+ predictions=predictions, references=references, documents=documents, pre_compute_embeddings=True
395
+ )
396
+ self._basic_assertion(result)
397
+
398
+ def test_compute_with_debug(self):
399
+ predictions = ["The cat sat on the mat.", "The quick brown fox jumps over the lazy dog."]
400
+ references = ["A cat was sitting on a mat.", "A quick brown fox jumped over a lazy dog."]
401
+ documents = ["There was a cat on a mat.", "The quick brown fox jumped over the lazy dog."]
402
+
403
+ result = self.metric.compute(
404
+ predictions=predictions, references=references, documents=documents, debug=True
405
+ )
406
+ self._basic_assertion(result, debug=True)
407
+
408
+ def test_compute_invalid_input_format(self):
409
+ predictions = "The cat sat on the mat."
410
+ references = ["A cat was sitting on a mat."]
411
+ documents = ["There was a cat on a mat."]
412
+
413
+ with self.assertRaises(ValueError):
414
+ self.metric.compute(predictions=predictions, references=references, documents=documents)
415
+
416
+
417
+ if __name__ == '__main__':
418
+ unittest.main(verbosity=2)
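For orientation, the snippet below mirrors the call pattern exercised by `TestSemnCG` above. It is a minimal sketch rather than part of the commit: it assumes the `SemnCG` class defined in `semncg.py` is importable from the package root and that the `stsb-distilbert-base` model used in the tests can be downloaded.

```python
# Minimal sketch of the API exercised by TestSemnCG (import path assumed).
from semncg import SemnCG

metric = SemnCG("stsb-distilbert-base")

# compute() returns a (mean score, per-example scores) tuple, as the tests assert.
mean_score, per_example = metric.compute(
    predictions=["The cat sat on the mat."],
    references=["A cat was sitting on a mat."],
    documents=["There was a cat on a mat."],
)
print(mean_score)   # float in [0.0, 1.0]
print(per_example)  # list of per-example NCG values
```

As `test_compute_with_debug` shows, passing `debug=True` makes the second element a list of `RankedGains` objects instead of plain floats.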
type_aliases.py ADDED
@@ -0,0 +1,11 @@
1
+
2
+ from typing import List, Union, Tuple
3
+
4
+ from numpy.typing import NDArray
5
+
6
+
7
+ NumSentencesType = Union[List[int], List[List[int]]]
8
+ EmbeddingSlicesType = Union[List[NDArray], List[List[NDArray]]]
9
+ DEVICE_TYPE = Union[bool, str, int, List[Union[str, int]]]
10
+ ENCODER_DEVICE_TYPE = Union[str, int, List[Union[str, int]]]
11
+ DOCUMENT_TYPE = Union[List[str], List[List[str]]]
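The aliases above are plain type annotations; the hypothetical function below (not part of the repository) only illustrates how they might appear in a signature.

```python
# Hypothetical illustration -- describe_inputs does not exist in this repo.
from type_aliases import DEVICE_TYPE, DOCUMENT_TYPE

def describe_inputs(documents: DOCUMENT_TYPE, device: DEVICE_TYPE) -> str:
    """Accepts whole documents or pre-split sentences plus any supported device spec."""
    return f"{len(documents)} documents on device(s): {device}"
```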
utils.py ADDED
@@ -0,0 +1,280 @@
1
+
2
+ import string
3
+ from typing import List, Tuple, Union
4
+
5
+ import nltk
6
+ import numpy as np
7
+ from numpy.typing import NDArray
8
+ import torch
9
+
10
+ from .type_aliases import DEVICE_TYPE, ENCODER_DEVICE_TYPE, NumSentencesType, EmbeddingSlicesType
11
+
12
+
13
+ def get_gpu(gpu: DEVICE_TYPE) -> ENCODER_DEVICE_TYPE:
14
+ """
15
+ Determine the correct GPU device based on the provided input. In the following, output 0 means CUDA device 0.
16
+
17
+ Args:
18
+ gpu (Union[bool, str, int, List[Union[str, int]]]): Input specifying the GPU device(s):
19
+ - bool: If True, returns 0 if CUDA is available, otherwise returns "cpu".
20
+ - str: Can be "cpu", "gpu", or "cuda" (case-insensitive). Returns 0 if CUDA is available
21
+ and the input is not "cpu", otherwise returns "cpu".
22
+ - int: Should be a valid GPU index. Returns the index if CUDA is available and valid,
23
+ otherwise returns "cpu".
24
+ - List[Union[str, int]]: List containing combinations of the str/int. Processes each
25
+ element and returns a list of corresponding results.
26
+
27
+ Returns:
28
+ Union[str, int, List[Union[str, int]]]: Depending on the input type:
29
+ - str: Returns "cpu" if no GPU is available or the input is "cpu".
30
+ - int: Returns the GPU index if valid and CUDA is available.
31
+ - List[Union[str, int]]: Returns a list of strings and/or integers based on the input list.
32
+
33
+ Raises:
34
+ ValueError: If the input gpu type is not recognized or invalid.
35
+ ValueError: If a string input is not one of ["cpu", "gpu", "cuda"].
36
+ ValueError: If an integer input is outside the valid range of GPU indices.
37
+
38
+ Notes:
39
+ - This function checks CUDA availability using torch.cuda.is_available() and counts
40
+ available GPUs using torch.cuda.device_count().
41
+ - Case insensitivity is maintained for string inputs ("cpu", "gpu", "cuda").
42
+ - The function ensures robust error handling for invalid input types or out-of-range indices.
43
+ """
44
+
45
+ # Ensure gpu index is within the range of total available gpus
46
+ gpu_available = torch.cuda.is_available()
47
+ gpu_count = torch.cuda.device_count()
48
+ correct_strs = ["cpu", "gpu", "cuda"]
49
+
50
+ def _get_single_device(gpu_item):
51
+ if isinstance(gpu_item, bool):
52
+ return 0 if gpu_item and gpu_available else "cpu"
53
+ elif isinstance(gpu_item, str):
54
+ if gpu_item.lower() not in correct_strs:
55
+ raise ValueError(f"Wrong gpu type: {gpu_item}. Should be one of {correct_strs}")
56
+ return 0 if (gpu_item.lower() != "cpu") and gpu_available else "cpu"
57
+ elif isinstance(gpu_item, int):
58
+ if gpu_item >= gpu_count:
59
+ raise ValueError(
60
+ f"There are {gpu_count} GPUs available. Provide a valid GPU index. You provided: {gpu_item}"
61
+ )
62
+ return gpu_item if gpu_available else "cpu"
63
+ else:
64
+ raise ValueError(f"Invalid gpu type: {type(gpu_item)}. Must be bool, str, or int.")
65
+
66
+ if isinstance(gpu, list):
67
+ seen_indices = set()
68
+ result = []
69
+ for item in gpu:
70
+ device = _get_single_device(item)
71
+ if isinstance(device, int):
72
+ if device not in seen_indices:
73
+ seen_indices.add(device)
74
+ result.append(device)
75
+ else:
76
+ result.append(device)
77
+ return result[0] if len(result) == 1 else result
78
+ else:
79
+ return _get_single_device(gpu)
80
+
81
+
82
+ def prep_sentences(sentences: List[str]) -> List[str]:
83
+ """
84
+ Processes a list of sentences by stripping leading and trailing whitespace and
85
+ filtering out sentences that are empty or contain only punctuation.
86
+
87
+ Args:
88
+ sentences (List[str]): A list of sentences to be processed.
89
+
90
+ Returns:
91
+ List[str]: A list of cleaned sentences
92
+
93
+ Raises:
94
+ ValueError: If the resulting list of sentences is empty.
95
+
96
+ Example:
97
+ >>> prep_sentences(["Hello, world!", " This is a test. ", "!!!"])
98
+ ['Hello, world!', 'This is a test.']
99
+
100
+ >>> prep_sentences(["!!!", "..."])
101
+ ValueError: Document can't be empty.
102
+ """
103
+ out = []
104
+ for sent in sentences:
105
+ sent = sent.strip()
106
+ sent_wo_punctuation = (
107
+ sent.translate(str.maketrans("", "", string.punctuation))
108
+ ).strip()
109
+ if sent_wo_punctuation:
110
+ out.append(sent)
111
+
112
+ if len(out) == 0:
113
+ raise ValueError("Document can't be empty.")
114
+ return out
115
+
116
+
117
+ def tokenize_and_prep_document(document: Union[str, List[str]], tokenize: bool) -> List[str]:
118
+ """
119
+ Tokenizes and prepares a document by either tokenizing it into sentences and processing each sentence,
120
+ or directly processing each element if `tokenize` is False.
121
+
122
+ Args:
123
+ document (Union[str, List[str]]): The document to be processed. It can be a single string (entire document) or a
124
+ list of strings (list of sentences).
125
+ tokenize (bool): If True, tokenizes `document` into sentences using NLTK's sentence tokenizer before processing.
126
+ If False, processes each element of `document` directly as sentences.
127
+
128
+ Returns:
129
+ List[str]: A list of cleaned sentences.
130
+
131
+ Raises:
132
+ ValueError: If the resulting list of sentences is empty after processing.
133
+
134
+ Example:
135
+ >>> tokenize_and_prep_document("Hello, world! This is a test.", True)
136
+ ['Hello, world!', 'This is a test.']
137
+
138
+ >>> tokenize_and_prep_document(["Hello, world!", "This is a test."], False)
139
+ ['Hello, world!', 'This is a test.']
140
+
141
+ >>> tokenize_and_prep_document("!!! ...", True)
142
+ ValueError: Document can't be empty.
143
+
144
+ Note: Only the following two cases are possible.
145
+ tokenize=True -> document: str
146
+ tokenize=False -> document: List[str].
147
+ """
148
+ if tokenize:
149
+ return prep_sentences(nltk.tokenize.sent_tokenize(document))
150
+ return prep_sentences(document)
151
+
152
+
153
+ def flatten_list(nested_list: list) -> list:
154
+ """
155
+ Recursively flattens a nested list of any depth.
156
+
157
+ Parameters:
158
+ nested_list (list): The nested list to flatten.
159
+
160
+ Returns:
161
+ list: A flat list containing all the elements of the nested list.
162
+ """
163
+ flat_list = []
164
+ for item in nested_list:
165
+ if isinstance(item, list):
166
+ flat_list.extend(flatten_list(item))
167
+ else:
168
+ flat_list.append(item)
169
+ return flat_list
170
+
171
+
172
+ def is_nested_list_of_type(lst_obj, element_type, depth: int) -> bool:
173
+ """
174
+ Check if the given object is a nested list of a specific type up to a specified depth.
175
+
176
+ Args:
177
+ - lst_obj: The object to check, expected to be a list or a single element.
178
+ - element_type: The type that each element in the nested list should match.
179
+ - depth (int): The depth of nesting to check. Must be non-negative.
180
+
181
+ Returns:
182
+ - bool: True if lst_obj is a nested list of the specified type up to the given depth, False otherwise.
183
+
184
+ Raises:
185
+ - ValueError: If depth is negative.
186
+
187
+ Example:
188
+ ```python
189
+ # Test cases
190
+ is_nested_list_of_type("test", str, 0) # Returns True
191
+ is_nested_list_of_type([1, 2, 3], str, 0) # Returns False
192
+ is_nested_list_of_type(["apple", "banana"], str, 1) # Returns True
193
+ is_nested_list_of_type([[1, 2], [3, 4]], int, 2) # Returns True
194
+ is_nested_list_of_type([[1, 2], ["a", "b"]], int, 2) # Returns False
195
+ is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3) # Returns True
196
+ ```
197
+
198
+ Explanation:
199
+ - The function checks if `lst_obj` is a nested list of elements of type `element_type` up to `depth` levels deep.
200
+ - If `depth` is 0, it checks if `lst_obj` itself is of type `element_type`.
201
+ - If `depth` is greater than 0, it recursively checks each level of nesting to ensure all elements match `element_type`.
202
+ - Raises a `ValueError` if `depth` is negative, as depth must be a non-negative integer.
203
+ """
204
+ if depth == 0:
205
+ return isinstance(lst_obj, element_type)
206
+ elif depth > 0:
207
+ return isinstance(lst_obj, list) and all(is_nested_list_of_type(item, element_type, depth - 1) for item in lst_obj)
208
+ else:
209
+ raise ValueError("Depth can't be negative")
210
+
211
+
212
+ def slice_embeddings(embeddings: NDArray, num_sentences: NumSentencesType) -> EmbeddingSlicesType:
213
+ """
214
+ Slice embeddings into segments based on the provided number of sentences per segment.
215
+
216
+ Args:
217
+ - embeddings (np.ndarray): The array of embeddings to be sliced.
218
+ - num_sentences (Union[List[int], List[List[int]]]):
219
+ - If a list of integers: Specifies the number of embeddings to take in each slice.
220
+ - If a list of lists of integers: Specifies multiple nested levels of slicing.
221
+
222
+ Returns:
223
+ - List[np.ndarray]: A list of numpy arrays where each array represents a slice of embeddings.
224
+
225
+ Raises:
226
+ - TypeError: If `num_sentences` is not of type List[int] or List[List[int]].
227
+
228
+ Example Usage:
229
+
230
+ ```python
231
+ embeddings = np.random.rand(10, 5)
232
+ num_sentences = [3, 2, 5]
233
+ result = slice_embeddings(embeddings, num_sentences)
234
+ # `result` will be a list of numpy arrays:
235
+ # [embeddings[:3], embeddings[3:5], embeddings[5:]]
236
+
237
+ num_sentences_nested = [[2, 1], [3, 4]]
238
+ result_nested = slice_embeddings(embeddings, num_sentences_nested)
239
+ # `result_nested` will be a nested list of numpy arrays:
240
+ # [[embeddings[:2], embeddings[2:3]], [embeddings[3:6], embeddings[6:]]]
241
+
242
+ slice_embeddings(embeddings, "invalid") # Raises a TypeError
243
+ ```
244
+ """
245
+
246
+ def _slice_embeddings(s_idx: int, n_sentences: List[int]):
247
+ """
248
+ Helper function to slice embeddings starting from index `s_idx`.
249
+
250
+ Args:
251
+ - s_idx (int): Starting index for slicing.
252
+ - n_sentences (List[int]): List specifying number of sentences in each slice.
253
+
254
+ Returns:
255
+ - Tuple[List[np.ndarray], int]: A tuple containing a list of sliced embeddings and the next starting index.
256
+ """
257
+ _result = []
258
+ for count in n_sentences:
259
+ _result.append(embeddings[s_idx:s_idx + count])
260
+ s_idx += count
261
+ return _result, s_idx
262
+
263
+ if isinstance(num_sentences, list) and all(isinstance(item, int) for item in num_sentences):
264
+ result, _ = _slice_embeddings(0, num_sentences)
265
+ return result
266
+ elif isinstance(num_sentences, list) and all(
267
+ isinstance(sublist, list) and all(
268
+ isinstance(item, int) for item in sublist
269
+ )
270
+ for sublist in num_sentences
271
+ ):
272
+ nested_result = []
273
+ start_idx = 0
274
+ for nested_num_sentences in num_sentences:
275
+ embedding_slice, start_idx = _slice_embeddings(start_idx, nested_num_sentences)
276
+ nested_result.append(embedding_slice)
277
+
278
+ return nested_result
279
+ else:
280
+ raise TypeError(f"Incorrect Type for {num_sentences=}")
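Taken together, these helpers cover the steps around encoding: resolve a device, clean and sentence-split each document, then slice the flat embedding matrix back into per-document pieces. The sketch below is illustrative only; it stands in a random matrix for real encoder output, assumes the helpers are importable as shown, and requires NLTK's punkt tokenizer data.

```python
import numpy as np

from utils import get_gpu, slice_embeddings, tokenize_and_prep_document  # import path assumed

device = get_gpu("cpu")  # -> "cpu"; get_gpu(True) returns 0 when CUDA is available

docs = ["Hello, world! This is a test.", "Another short document."]
sentences = [tokenize_and_prep_document(d, tokenize=True) for d in docs]  # needs nltk punkt data
num_sentences = [len(s) for s in sentences]  # [2, 1]

# Stand-in for encoder output: one row per sentence, embedding dimension 8.
embeddings = np.random.rand(sum(num_sentences), 8)
per_doc = slice_embeddings(embeddings, num_sentences)
print([e.shape for e in per_doc])  # [(2, 8), (1, 8)]
```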