Spaces:
Sleeping
Sleeping
Added SemNCG metric
Browse files- .gitignore +1 -0
- README.md +85 -20
- __init__.py +0 -0
- encoder_models.py +129 -0
- requirements.txt +3 -1
- semncg.py +475 -45
- tests.py +418 -17
- type_aliases.py +11 -0
- utils.py +280 -0
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__pycache__/
|
README.md
CHANGED
@@ -5,46 +5,111 @@ datasets:
|
|
5 |
tags:
|
6 |
- evaluate
|
7 |
- metric
|
8 |
-
description: "
|
|
|
|
|
|
|
|
|
9 |
sdk: gradio
|
10 |
sdk_version: 3.19.1
|
11 |
app_file: app.py
|
12 |
pinned: false
|
13 |
---
|
14 |
|
15 |
-
# Metric Card for
|
16 |
-
|
17 |
-
***Module Card Instructions:*** *Fill out the following subsections. Feel free to take a look at existing metric cards if you'd like examples.*
|
18 |
|
19 |
## Metric Description
|
20 |
-
|
|
|
|
|
|
|
|
|
21 |
|
22 |
## How to Use
|
23 |
-
*Give general statement of how to use the metric*
|
24 |
|
25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
-
|
28 |
-
|
29 |
-
|
|
|
|
|
|
|
30 |
|
31 |
### Output Values
|
32 |
|
33 |
-
|
34 |
|
35 |
-
|
|
|
36 |
|
37 |
-
#### Values from Popular Papers
|
38 |
-
*Give examples, preferrably with links to leaderboards or publications, to papers that have reported this metric, along with the values they have reported.*
|
39 |
|
40 |
-
|
41 |
-
|
|
|
|
|
42 |
|
43 |
-
##
|
44 |
-
*Note any known limitations or biases that the metric has, with links and references if possible.*
|
45 |
|
46 |
## Citation
|
47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
## Further References
|
50 |
-
|
|
|
|
5 |
tags:
|
6 |
- evaluate
|
7 |
- metric
|
8 |
+
description: "Sem-nCG (Semantic Normalized Cumulative Gain) Metric evaluates the quality of predicted sentences
|
9 |
+
(abstractive/extractive) in relation to reference sentences and documents using Semantic Normalized Cumulative Gain
|
10 |
+
(NCG). It computes gain values and NCG scores based on cosine similarity between sentence embeddings, leveraging a
|
11 |
+
Sentence-BERT encoder. This metric is designed to assess the relevance and ranking of predicted sentences, making it
|
12 |
+
useful for tasks such as summarization and information retrieval."
|
13 |
sdk: gradio
|
14 |
sdk_version: 3.19.1
|
15 |
app_file: app.py
|
16 |
pinned: false
|
17 |
---
|
18 |
|
19 |
+
# Metric Card for Sem-nCG
|
|
|
|
|
20 |
|
21 |
## Metric Description
|
22 |
+
Sem-nCG (Semantic Normalized Cumulative Gain) metric evaluates system-generated summaries (`predictions`) by comparing
|
23 |
+
them with ground truth reference summaries (`references`) and input documents (`documents`). It computes the Semantic
|
24 |
+
Normalized Cumulative Gain (NCG) scores based on sentence embeddings, which assess the quality of summaries by
|
25 |
+
evaluating the relevance of predicted sentences to the reference and input document sentences.
|
26 |
+
|
27 |
|
28 |
## How to Use
|
|
|
29 |
|
30 |
+
Before using this metric, you need to install the dependencies:
|
31 |
+
```bash
|
32 |
+
pip install -U sentence-transformers nltk
|
33 |
+
```
|
34 |
+
|
35 |
+
Sem-nCG takes three mandatory arguments:
|
36 |
+
- `predictions` - List of predictions
|
37 |
+
- `references` - List of references
|
38 |
+
- `documents` - List of input documents
|
39 |
+
|
40 |
+
```python
|
41 |
+
from evaluate import load
|
42 |
+
predictions = [
|
43 |
+
"This is a prediction1 sentence 1. This is a prediction1 sentence 2.",
|
44 |
+
"This is a prediction2 sentence 1."
|
45 |
+
]
|
46 |
+
references = [
|
47 |
+
"This is a reference1 sentence 1. This is a reference1 sentence 2.",
|
48 |
+
"This is a reference2 sentence 1. This is a reference2 sentence 2."
|
49 |
+
]
|
50 |
+
documents = [
|
51 |
+
"This is a document1 sentence 1. This is a document1 sentence 2. This is a document1 sentence 3.",
|
52 |
+
"This is a document2 sentence 1. This is a document2 sentence 2."
|
53 |
+
]
|
54 |
+
model_name = "all-MiniLM-L6-v2"
|
55 |
+
metric = load("nbansal/semncg", model_name=model_name) # model_name is optional. Default=all-MiniLM-L6-v2
|
56 |
+
mean_score, scores = metric.compute(predictions=predictions, references=references, documents=documents)
|
57 |
+
print(f"Mean SemnCG: {mean_score}")
|
58 |
+
```
|
59 |
+
|
60 |
+
Sem-nCG also accepts several optional arguments:
|
61 |
+
- `tokenize_sentences (bool)`: Flag to indicate whether to tokenize the sentences in the input documents. Default: True
|
62 |
+
- `pre_compute_embeddings (bool)`: Flag to indicate whether to pre-compute embeddings for all sentences. Default=False
|
63 |
+
- `k (int)`: The rank threshold used for evaluating gains (typically top-k sentences). Default is 3.
|
64 |
+
- `gpu (Union[bool, str, int, List[Union[str, int]]])`: Whether to use GPU, CPU, or multiple processes for computation.
|
65 |
+
- `batch_size (int)`: Batch size for encoding. Default is 32.
|
66 |
+
- `verbose (bool)`: Flag to indicate verbose output. Default is False.
|
67 |
+
- `debug (bool)`: Flag to return detailed debug information including ranked gains. Default is False.
|
68 |
|
69 |
+
Refer to the inputs descriptions for more detailed usage as follows:
|
70 |
+
```python
|
71 |
+
import evaluate
|
72 |
+
metric = evaluate.load("nbansal/semncg")
|
73 |
+
print(metric.inputs_description)
|
74 |
+
```
|
75 |
|
76 |
### Output Values
|
77 |
|
78 |
+
The output is a tuple containing:
|
79 |
|
80 |
+
Mean Sem-nCG score: float: The average Sem-nCG score.
|
81 |
+
scores: List[Union[float, RankedGains]]: List of Sem-nCG scores or RankedGains objects for each document.
|
82 |
|
|
|
|
|
83 |
|
84 |
+
## Extensions
|
85 |
+
The current implementation supports any model from Huggingface/SentenceTransformer that is compatible with
|
86 |
+
SentenceTransformer, such as `all-mpnet-base-v2` or `roberta-base`. You can extend the metric with more models by
|
87 |
+
extending the `Encoder` base class in the `encoder_models.py` file.
|
88 |
|
89 |
+
## Deviations from Published Methodology
|
|
|
90 |
|
91 |
## Citation
|
92 |
+
```bibtex
|
93 |
+
@inproceedings{akter-etal-2022-revisiting,
|
94 |
+
title = "Revisiting Automatic Evaluation of Extractive Summarization Task: Can We Do Better than {ROUGE}?",
|
95 |
+
author = "Akter, Mousumi and
|
96 |
+
Bansal, Naman and
|
97 |
+
Karmaker, Shubhra Kanti",
|
98 |
+
editor = "Muresan, Smaranda and
|
99 |
+
Nakov, Preslav and
|
100 |
+
Villavicencio, Aline",
|
101 |
+
booktitle = "Findings of the Association for Computational Linguistics: ACL 2022",
|
102 |
+
month = may,
|
103 |
+
year = "2022",
|
104 |
+
address = "Dublin, Ireland",
|
105 |
+
publisher = "Association for Computational Linguistics",
|
106 |
+
url = "https://aclanthology.org/2022.findings-acl.122",
|
107 |
+
doi = "10.18653/v1/2022.findings-acl.122",
|
108 |
+
pages = "1547--1560",
|
109 |
+
abstract = "It has been the norm for a long time to evaluate automated summarization tasks using the popular ROUGE metric. Although several studies in the past have highlighted the limitations of ROUGE, researchers have struggled to reach a consensus on a better alternative until today. One major limitation of the traditional ROUGE metric is the lack of semantic understanding (relies on direct overlap of n-grams). In this paper, we exclusively focus on the extractive summarization task and propose a semantic-aware nCG (normalized cumulative gain)-based evaluation metric (called Sem-nCG) for evaluating this task. One fundamental contribution of the paper is that it demonstrates how we can generate more reliable semantic-aware ground truths for evaluating extractive summarization tasks without any additional human intervention. To the best of our knowledge, this work is the first of its kind. We have conducted extensive experiments with this new metric using the widely used CNN/DailyMail dataset. Experimental results show that the new Sem-nCG metric is indeed semantic-aware, shows higher correlation with human judgement (more reliable) and yields a large number of disagreements with the original ROUGE metric (suggesting that ROUGE often leads to inaccurate conclusions also verified by humans).",
|
110 |
+
}
|
111 |
+
```
|
112 |
|
113 |
## Further References
|
114 |
+
- [Paper](https://aclanthology.org/2022.findings-acl.122/)
|
115 |
+
- [Video](https://underline.io/lecture/50182-findings-revisiting-automatic-evaluation-of-extractive-summarization-task-can-we-do-better-than-rougequestion)
|
__init__.py
ADDED
File without changes
|
encoder_models.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import abc
|
2 |
+
from typing import List, Union
|
3 |
+
|
4 |
+
from numpy.typing import NDArray
|
5 |
+
from sentence_transformers import SentenceTransformer
|
6 |
+
|
7 |
+
from .type_aliases import ENCODER_DEVICE_TYPE
|
8 |
+
|
9 |
+
|
10 |
+
class Encoder(abc.ABC):
|
11 |
+
@abc.abstractmethod
|
12 |
+
def encode(self, prediction: List[str]) -> NDArray:
|
13 |
+
"""
|
14 |
+
Abstract method to encode a list of sentences into sentence embeddings.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
prediction (List[str]): List of sentences to encode.
|
18 |
+
|
19 |
+
Returns:
|
20 |
+
NDArray: Array of sentence embeddings with shape (num_sentences, embedding_dim).
|
21 |
+
|
22 |
+
Raises:
|
23 |
+
NotImplementedError: If the method is not implemented in the subclass.
|
24 |
+
"""
|
25 |
+
raise NotImplementedError("Method 'encode' must be implemented in subclass.")
|
26 |
+
|
27 |
+
|
28 |
+
class SBertEncoder(Encoder):
|
29 |
+
def __init__(self, model: SentenceTransformer, device: ENCODER_DEVICE_TYPE, batch_size: int, verbose: bool):
|
30 |
+
"""
|
31 |
+
Initialize SBertEncoder instance.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
model (SentenceTransformer): The Sentence Transformer model instance to use for encoding.
|
35 |
+
device (Union[str, int, List[Union[str, int]]]): Device specification for encoding
|
36 |
+
batch_size (int): Batch size for encoding.
|
37 |
+
verbose (bool): Whether to print verbose information during encoding.
|
38 |
+
"""
|
39 |
+
self.model = model
|
40 |
+
self.device = device
|
41 |
+
self.batch_size = batch_size
|
42 |
+
self.verbose = verbose
|
43 |
+
|
44 |
+
def encode(self, prediction: List[str]) -> NDArray:
|
45 |
+
"""
|
46 |
+
Encode a list of sentences into sentence embeddings.
|
47 |
+
|
48 |
+
Args:
|
49 |
+
prediction (List[str]): List of sentences to encode.
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
NDArray: Array of sentence embeddings with shape (num_sentences, embedding_dim).
|
53 |
+
"""
|
54 |
+
|
55 |
+
# SBert output is always Batch x Dim
|
56 |
+
if isinstance(self.device, list):
|
57 |
+
# Use multiprocess encoding for list of devices
|
58 |
+
pool = self.model.start_multi_process_pool(target_devices=self.device)
|
59 |
+
embeddings = self.model.encode_multi_process(prediction, pool=pool, batch_size=self.batch_size)
|
60 |
+
self.model.stop_multi_process_pool(pool)
|
61 |
+
else:
|
62 |
+
# Single device encoding
|
63 |
+
embeddings = self.model.encode(
|
64 |
+
prediction,
|
65 |
+
device=self.device,
|
66 |
+
batch_size=self.batch_size,
|
67 |
+
show_progress_bar=self.verbose,
|
68 |
+
)
|
69 |
+
|
70 |
+
return embeddings
|
71 |
+
|
72 |
+
|
73 |
+
def get_encoder(
|
74 |
+
sbert_model: SentenceTransformer,
|
75 |
+
device: ENCODER_DEVICE_TYPE,
|
76 |
+
batch_size: int,
|
77 |
+
verbose: bool,
|
78 |
+
) -> Encoder:
|
79 |
+
"""
|
80 |
+
Get an instance of SBertEncoder using the provided parameters.
|
81 |
+
|
82 |
+
Args:
|
83 |
+
sbert_model (SentenceTransformer): An instance of SentenceTransformer model to use for encoding.
|
84 |
+
device (Union[str, int, List[Union[str, int]]): Device specification for the encoder
|
85 |
+
(e.g., "cuda", 0 for GPU, "cpu").
|
86 |
+
batch_size (int): Batch size to use for encoding.
|
87 |
+
verbose (bool): Whether to print verbose information during encoding.
|
88 |
+
|
89 |
+
Returns:
|
90 |
+
SBertEncoder: Instance of the selected encoder based on the model_name.
|
91 |
+
|
92 |
+
Example:
|
93 |
+
>>> model_name = "paraphrase-distilroberta-base-v1"
|
94 |
+
>>> sbert_model = get_sbert_encoder(model_name)
|
95 |
+
>>> device = get_gpu("cuda")
|
96 |
+
>>> batch_size = 32
|
97 |
+
>>> verbose = True
|
98 |
+
>>> encoder = get_encoder(sbert_model, device, batch_size, verbose)
|
99 |
+
"""
|
100 |
+
encoder = SBertEncoder(sbert_model, device, batch_size, verbose)
|
101 |
+
return encoder
|
102 |
+
|
103 |
+
|
104 |
+
def get_sbert_encoder(model_name: str) -> SentenceTransformer:
|
105 |
+
"""
|
106 |
+
Get an instance of SentenceTransformer encoder based on the specified model name.
|
107 |
+
|
108 |
+
Args:
|
109 |
+
model_name (str): Name of the model to instantiate. You can use any model on Huggingface/SentenceTransformer
|
110 |
+
that is supported by SentenceTransformer.
|
111 |
+
|
112 |
+
Returns:
|
113 |
+
SentenceTransformer: Instance of the selected encoder based on the model_name.
|
114 |
+
|
115 |
+
Raises:
|
116 |
+
EnvironmentError: If an unsupported model_name is provided.
|
117 |
+
RuntimeError: If there's an issue during instantiation of the encoder.
|
118 |
+
"""
|
119 |
+
|
120 |
+
try:
|
121 |
+
encoder = SentenceTransformer(model_name, trust_remote_code=True)
|
122 |
+
except EnvironmentError as err:
|
123 |
+
raise EnvironmentError(str(err)) from None
|
124 |
+
except Exception as err:
|
125 |
+
raise RuntimeError(str(err)) from None
|
126 |
+
|
127 |
+
return encoder
|
128 |
+
|
129 |
+
|
requirements.txt
CHANGED
@@ -1 +1,3 @@
|
|
1 |
-
git+https://github.com/huggingface/evaluate@main
|
|
|
|
|
|
1 |
+
git+https://github.com/huggingface/evaluate@main
|
2 |
+
nltk
|
3 |
+
sentence-transformers
|
semncg.py
CHANGED
@@ -11,55 +11,340 @@
|
|
11 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
# See the License for the specific language governing permissions and
|
13 |
# limitations under the License.
|
14 |
-
"""
|
15 |
|
|
|
16 |
import evaluate
|
17 |
import datasets
|
|
|
|
|
|
|
18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
-
# TODO: Add BibTeX citation
|
21 |
_CITATION = """\
|
22 |
-
@
|
23 |
-
title =
|
24 |
-
|
25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
}
|
27 |
"""
|
28 |
|
29 |
-
# TODO: Add description of the module here
|
30 |
_DESCRIPTION = """\
|
31 |
-
|
|
|
|
|
|
|
|
|
32 |
"""
|
33 |
|
34 |
-
|
35 |
-
# TODO: Add description of the arguments of the module here
|
36 |
_KWARGS_DESCRIPTION = """
|
37 |
-
|
|
|
|
|
|
|
38 |
Args:
|
39 |
-
predictions:
|
40 |
-
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
Returns:
|
44 |
-
|
45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
Examples:
|
47 |
-
Examples should be written in doctest format, and should illustrate how
|
48 |
-
to use the function.
|
49 |
|
50 |
-
>>>
|
51 |
-
>>>
|
52 |
-
>>>
|
53 |
-
|
|
|
|
|
|
|
54 |
"""
|
55 |
|
56 |
-
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
|
60 |
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
61 |
class SemnCG(evaluate.Metric):
|
62 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
|
64 |
def _info(self):
|
65 |
# TODO: Specifies the evaluate.EvaluationModuleInfo object
|
@@ -70,26 +355,171 @@ class SemnCG(evaluate.Metric):
|
|
70 |
citation=_CITATION,
|
71 |
inputs_description=_KWARGS_DESCRIPTION,
|
72 |
# This defines the format of each prediction and reference
|
73 |
-
features=
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
)
|
83 |
|
84 |
def _download_and_prepare(self, dl_manager):
|
85 |
"""Optional: download external resources useful to compute the scores"""
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
# See the License for the specific language governing permissions and
|
13 |
# limitations under the License.
|
14 |
+
"""Sem-NCG metric"""
|
15 |
|
16 |
+
from dataclasses import dataclass
|
17 |
import evaluate
|
18 |
import datasets
|
19 |
+
import re
|
20 |
+
import statistics
|
21 |
+
from typing import Dict, List, Tuple, Union
|
22 |
|
23 |
+
import nltk
|
24 |
+
import numpy as np
|
25 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
26 |
+
from tqdm import tqdm
|
27 |
+
|
28 |
+
from .encoder_models import get_sbert_encoder, get_encoder
|
29 |
+
from .type_aliases import DEVICE_TYPE, NDArray, DOCUMENT_TYPE
|
30 |
+
from .utils import get_gpu, prep_sentences, flatten_list, slice_embeddings, is_nested_list_of_type, tokenize_and_prep_document
|
31 |
|
|
|
32 |
_CITATION = """\
|
33 |
+
@inproceedings{akter-etal-2022-revisiting,
|
34 |
+
title = "Revisiting Automatic Evaluation of Extractive Summarization Task: Can We Do Better than {ROUGE}?",
|
35 |
+
author = "Akter, Mousumi and
|
36 |
+
Bansal, Naman and
|
37 |
+
Karmaker, Shubhra Kanti",
|
38 |
+
editor = "Muresan, Smaranda and
|
39 |
+
Nakov, Preslav and
|
40 |
+
Villavicencio, Aline",
|
41 |
+
booktitle = "Findings of the Association for Computational Linguistics: ACL 2022",
|
42 |
+
month = may,
|
43 |
+
year = "2022",
|
44 |
+
address = "Dublin, Ireland",
|
45 |
+
publisher = "Association for Computational Linguistics",
|
46 |
+
url = "https://aclanthology.org/2022.findings-acl.122",
|
47 |
+
doi = "10.18653/v1/2022.findings-acl.122",
|
48 |
+
pages = "1547--1560",
|
49 |
+
abstract = "It has been the norm for a long time to evaluate automated summarization tasks using the popular ROUGE metric. Although several studies in the past have highlighted the limitations of ROUGE, researchers have struggled to reach a consensus on a better alternative until today. One major limitation of the traditional ROUGE metric is the lack of semantic understanding (relies on direct overlap of n-grams). In this paper, we exclusively focus on the extractive summarization task and propose a semantic-aware nCG (normalized cumulative gain)-based evaluation metric (called Sem-nCG) for evaluating this task. One fundamental contribution of the paper is that it demonstrates how we can generate more reliable semantic-aware ground truths for evaluating extractive summarization tasks without any additional human intervention. To the best of our knowledge, this work is the first of its kind. We have conducted extensive experiments with this new metric using the widely used CNN/DailyMail dataset. Experimental results show that the new Sem-nCG metric is indeed semantic-aware, shows higher correlation with human judgement (more reliable) and yields a large number of disagreements with the original ROUGE metric (suggesting that ROUGE often leads to inaccurate conclusions also verified by humans).",
|
50 |
}
|
51 |
"""
|
52 |
|
|
|
53 |
_DESCRIPTION = """\
|
54 |
+
Sem-nCG (Semantic Normalized Cumulative Gain) Metric evaluates the quality of predicted sentences
|
55 |
+
(abstractive/extractive) in relation to reference sentences and documents using Semantic Normalized Cumulative Gain
|
56 |
+
(NCG). It computes gain values and NCG scores based on cosine similarity between sentence embeddings, leveraging a
|
57 |
+
Sentence-BERT encoder. This metric is designed to assess the relevance and ranking of predicted sentences, making it
|
58 |
+
useful for tasks such as summarization and information retrieval.
|
59 |
"""
|
60 |
|
|
|
|
|
61 |
_KWARGS_DESCRIPTION = """
|
62 |
+
Sem-nCG (Semantic Normalized Cumulative Gain) compares the system-generated summaries (predictions) with ground truth
|
63 |
+
reference summaries (references) and input documents (documents) using Semantic Normalized Cumulative Gain (NCG).
|
64 |
+
It computes gain values and NCG scores based on sentence embeddings.
|
65 |
+
|
66 |
Args:
|
67 |
+
predictions (DOCUMENT_TYPE): The predicted sentences.
|
68 |
+
`tokenize_sentences`=True -> predictions: List[str]
|
69 |
+
`tokenize_sentences`=False -> predictions: List[List[str]]
|
70 |
+
references (DOCUMENT_TYPE): The reference sentences.
|
71 |
+
`tokenize_sentences`=True -> references: List[str]
|
72 |
+
`tokenize_sentences`=False -> references: List[List[str]]
|
73 |
+
documents (DOCUMENT_TYPE): Input documents.
|
74 |
+
`tokenize_sentences`=True -> documents: List[str]
|
75 |
+
`tokenize_sentences`=False -> documents: List[List[str]]
|
76 |
+
k (int): The rank threshold used for evaluating gains (typically top-k sentences). Default is 3.
|
77 |
+
gpu (Union[bool, str, int, List[Union[str, int]]]): Whether to use GPU or CPU for computation.
|
78 |
+
bool -
|
79 |
+
False - CPU (Default)
|
80 |
+
True - GPU (device 0) if gpu is available else CPU
|
81 |
+
int -
|
82 |
+
n - GPU, device index n
|
83 |
+
str -
|
84 |
+
'cuda', 'gpu', 'cpu'
|
85 |
+
List[Union[str, int]] - Multiple GPUs/cpus i.e. use multiple processes when computing embeddings
|
86 |
+
batch_size (int): Batch size for encoding. Default is 32.
|
87 |
+
verbose (bool): Flag to indicate verbose output. Default is False.
|
88 |
+
tokenize_sentences (bool): Flag to indicate whether to tokenize the sentences in the input documents. Default: True.
|
89 |
+
pre_compute_embeddings (bool): Flag to indicate whether to pre-compute embeddings for all sentences. This speeds up
|
90 |
+
computation but requires more memory. Default is False.
|
91 |
+
debug (bool): Flag to return detailed debug information including ranked gains. Default is False.
|
92 |
+
|
93 |
Returns:
|
94 |
+
Union[Tuple[float, List[float]], Tuple[float, List[RankedGains]]]:
|
95 |
+
If `debug` is False, returns a tuple containing the mean SemnCG score and a list of SemnCG scores for each document.
|
96 |
+
If `debug` is True, returns a tuple containing the mean SemnCG score and a list of `RankedGains` objects with
|
97 |
+
detailed gain information for each document.
|
98 |
+
|
99 |
+
Examples of input formats:
|
100 |
+
|
101 |
+
Case 1: tokenize_sentences = True
|
102 |
+
predictions: List[str] - List of predictions where each prediction is a document.
|
103 |
+
references: List[str] - List of references where each reference is a document.
|
104 |
+
documents: List[str] - List of input documents where each document is a document.
|
105 |
+
Example:
|
106 |
+
predictions = ["This is a prediction sentence 1. This is a prediction sentence 2."]
|
107 |
+
references = ["This is a reference sentence 1. This is a reference sentence 2."]
|
108 |
+
documents = ["This is a document sentence 1. This is a document sentence 2."]
|
109 |
+
|
110 |
+
Case 2: tokenize_sentences = False
|
111 |
+
predictions: List[List[str]] - List of predictions where each prediction is a list of sentences.
|
112 |
+
references: List[List[str]] - List of references where each reference is a list of sentences.
|
113 |
+
documents: List[List[str]] - List of input documents where each document is a list of sentences.
|
114 |
+
Example:
|
115 |
+
predictions = [["This is a prediction sentence 1.", "This is a prediction sentence 2."]]
|
116 |
+
references = [["This is a reference sentence 1.", "This is a reference sentence 2."]]
|
117 |
+
documents = [["This is a document sentence 1.", "This is a document sentence 2."]]
|
118 |
+
|
119 |
Examples:
|
|
|
|
|
120 |
|
121 |
+
>>> import evaluate
|
122 |
+
>>> predictions = ["This is a prediction sentence 1. This is a prediction sentence 2."]
|
123 |
+
>>> references = ["This is a reference sentence 1. This is a reference sentence 2."]
|
124 |
+
>>> documents = ["This is a document sentence 1. This is a document sentence 2."]
|
125 |
+
>>> metric = evaluate.load("nbansal/semncg", model_name="all-MiniLM-L6-v2")
|
126 |
+
>>> mean_score, scores = metric.compute(predictions=predictions, references=references, documents=documents)
|
127 |
+
>>> print(f"Mean SemnCG: {mean_score}")
|
128 |
"""
|
129 |
|
130 |
+
|
131 |
+
|
132 |
+
|
133 |
+
@dataclass
|
134 |
+
class RankedGains:
|
135 |
+
"""
|
136 |
+
Dataclass to store ranked gains and associated metadata.
|
137 |
+
|
138 |
+
Attributes:
|
139 |
+
gt_gains (List[Tuple[str, float]]): List of tuples representing ground truth (ideal) gains,
|
140 |
+
where each tuple contains a document sentence and its corresponding gain value.
|
141 |
+
pred_gains (List[Tuple[str, float]]): List of tuples representing predicted gains by the model,
|
142 |
+
where each tuple contains a document identifier and its corresponding gain value.
|
143 |
+
k (int): The rank threshold used for evaluating gains (typically top-k documents).
|
144 |
+
ncg (float): Normalized Cumulative Gain (NCG) score calculated based on the predicted gains
|
145 |
+
compared to the ground truth gains.
|
146 |
+
|
147 |
+
Notes:
|
148 |
+
- `gt_gains` and `pred_gains` are typically sorted in descending order
|
149 |
+
- `k` specifies the top-k threshold used for evaluating the gains.
|
150 |
+
- `ncg` provides a normalized measure of the model's performance.
|
151 |
+
"""
|
152 |
+
gt_gains: List[Tuple[str, float]]
|
153 |
+
pred_gains: List[Tuple[str, float]]
|
154 |
+
k: int
|
155 |
+
ncg: float
|
156 |
+
|
157 |
+
|
158 |
+
def compute_cosine_similarity(doc_embeds: NDArray, ref_embeds: NDArray) -> List[float]:
|
159 |
+
"""
|
160 |
+
Compute cosine similarity scores between each document embedding and reference embeddings.
|
161 |
+
|
162 |
+
Args:
|
163 |
+
doc_embeds (NDArray): 2D array of shape (#Docs, Embedding_dim) containing document embeddings.
|
164 |
+
ref_embeds (NDArray): 2D array of shape (#Refs, Embedding_dim) containing reference embeddings.
|
165 |
+
|
166 |
+
Returns:
|
167 |
+
List[float]: A list of mean cosine similarity scores between each document and reference embeddings.
|
168 |
+
The length of the list is equal to the number of documents (#Docs).
|
169 |
+
|
170 |
+
Notes:
|
171 |
+
- Uses cosine_similarity function from sklearn.metrics.pairwise to compute pairwise cosine similarities.
|
172 |
+
- Returns the mean cosine similarity scores across reference embeddings for each document embedding.
|
173 |
+
"""
|
174 |
+
# Compute cosine similarity between predicted and reference embeddings
|
175 |
+
cosine_scores = cosine_similarity(doc_embeds, ref_embeds) # [#Docs, #Refs]
|
176 |
+
return np.mean(cosine_scores, axis=1).tolist()
|
177 |
+
|
178 |
+
|
179 |
+
def compute_gain(sim_scores: List[float]) -> List[Tuple[int, float]]:
|
180 |
+
"""
|
181 |
+
Compute gain values for ranked similarity scores.
|
182 |
+
|
183 |
+
Args:
|
184 |
+
sim_scores (List[float]): List of similarity scores for documents (`compute_cosine_similarity(doc_embeds, ref_embeds)`)
|
185 |
+
|
186 |
+
Returns:
|
187 |
+
List[Tuple[int, float]]: A list of tuples where each tuple contains a document index and its corresponding gain
|
188 |
+
value. The list is sorted by descending order of gain values.
|
189 |
+
|
190 |
+
Notes:
|
191 |
+
- Computes gain values based on the rank order of similarity scores, where higher scores indicate higher gains.
|
192 |
+
- Uses the formula: gain = rank_position / sum of ranks, where rank_position starts from 1 for the highest score
|
193 |
+
- Returns a list sorted by descending gain values.
|
194 |
+
"""
|
195 |
+
count = len(sim_scores)
|
196 |
+
sim_scores = np.array(sim_scores).argsort()[::-1] # Reverse Sorted Order of doc sentence indices
|
197 |
+
denominator = count * (count + 1) / 2 # (n * (n+1))/2
|
198 |
+
return [(s_idx, val / denominator) for s_idx, val in zip(sim_scores, range(count, 0, -1))]
|
199 |
+
|
200 |
+
|
201 |
+
def score_ncg(model_relevance: List[float], gt_relevance: List[float]) -> float:
|
202 |
+
"""
|
203 |
+
Calculate the Normalized Cumulative Gain (NCG) score based on model relevance and ground truth relevance.
|
204 |
+
|
205 |
+
Args:
|
206 |
+
model_relevance (List[float]): List of gain values representing the relevance scores predicted by the model.
|
207 |
+
gt_relevance (List[float]): List of gain values representing the ground truth (ideal) relevance scores.
|
208 |
+
|
209 |
+
Returns:
|
210 |
+
float: Normalized Cumulative Gain (NCG) score, which measures the effectiveness of the model's relevance
|
211 |
+
predictions compared to the ideal relevance scores. The score ranges from 0 to 1, where higher values
|
212 |
+
indicate better performance.
|
213 |
+
|
214 |
+
Notes:
|
215 |
+
- Calculates Cumulative Gain (CG) for both model and ground truth relevance lists.
|
216 |
+
- Normalizes CG scores by dividing model CG by ground truth CG to get the NCG score.
|
217 |
+
- Returns 0 if the ground truth CG (icg) is 0 to avoid division by zero.
|
218 |
+
"""
|
219 |
+
|
220 |
+
# CG score
|
221 |
+
cg = sum(model_relevance)
|
222 |
+
|
223 |
+
# ICG score
|
224 |
+
icg = sum(gt_relevance)
|
225 |
+
|
226 |
+
# Normalized CG score
|
227 |
+
return cg / icg if icg != 0 else 0
|
228 |
+
|
229 |
+
|
230 |
+
def compute_ncg(pred_gains: List[Tuple[int, float]], gt_gains: List[Tuple[int, float]], k: int) -> float:
|
231 |
+
"""
|
232 |
+
Compute the Normalized Cumulative Gain (NCG) score based on predicted and ground truth gains up to rank k.
|
233 |
+
|
234 |
+
Args:
|
235 |
+
pred_gains (List[Tuple[int, float]]): List of tuples representing predicted gains by the model,
|
236 |
+
where each tuple contains a document position (or index) and its corresponding gain value.
|
237 |
+
(Sorted in Descending Order)
|
238 |
+
gt_gains (List[Tuple[int, float]]): List of tuples representing ground truth gains (ideal gains),
|
239 |
+
where each tuple contains a document position (or index) and its corresponding gain value.
|
240 |
+
(Sorted in Descending Order)
|
241 |
+
k (int): The rank threshold used for evaluating gains (typically top-k documents).
|
242 |
+
|
243 |
+
Returns:
|
244 |
+
float: Normalized Cumulative Gain (NCG) score based on the predicted gains compared to the ground truth gains.
|
245 |
+
|
246 |
+
Notes:
|
247 |
+
- Both `pred_gains` and `gt_gains` should be sorted lists (in descending order) where higher gain values indicate
|
248 |
+
higher relevance.
|
249 |
+
- The function calculates NCG up to rank `k`, considering only the top-k documents.
|
250 |
+
- Uses the `score_ncg` function to compute the NCG score based on the model's predicted gains and the ground
|
251 |
+
truth.
|
252 |
+
"""
|
253 |
+
gt_dict = dict(gt_gains)
|
254 |
+
gt_rel = [v for _, v in gt_gains[:k]]
|
255 |
+
model_rel = [gt_dict[position] for position, _ in pred_gains[:k]]
|
256 |
+
return score_ncg(model_rel, gt_rel)
|
257 |
+
|
258 |
+
|
259 |
+
def _validate_input_format(
|
260 |
+
tokenize_sentences: bool,
|
261 |
+
predictions: DOCUMENT_TYPE,
|
262 |
+
references: DOCUMENT_TYPE,
|
263 |
+
documents: DOCUMENT_TYPE
|
264 |
+
):
|
265 |
+
"""
|
266 |
+
Validate the format of predictions, references, and documents based on specified criteria.
|
267 |
+
|
268 |
+
Args:
|
269 |
+
tokenize_sentences (bool): Flag indicating whether sentences should be tokenized.
|
270 |
+
predictions (DOCUMENT_TYPE): Predictions to validate.
|
271 |
+
references (DOCUMENT_TYPE): References to validate.
|
272 |
+
documents (DOCUMENT_TYPE): Documents to validate.
|
273 |
+
|
274 |
+
Raises:
|
275 |
+
ValueError: If the format of predictions, references, or documents does not meet the specified criteria.
|
276 |
+
|
277 |
+
Validation Criteria:
|
278 |
+
The function validates predictions, references, and documents based on the following conditions:
|
279 |
+
1. If `tokenize_sentences` is True:
|
280 |
+
- Predictions, references, and documents must all be lists of strings (`is_list_of_strings_at_depth(obj, 1)`).
|
281 |
+
|
282 |
+
2. If `tokenize_sentences` is False:
|
283 |
+
- Predictions, references, and documents must all be lists of lists of strings
|
284 |
+
(`is_list_of_strings_at_depth(obj, 2)`).
|
285 |
+
|
286 |
+
The function checks these conditions and raises a ValueError if any condition is not met,
|
287 |
+
indicating that predictions, references, or documents are not in the valid input format.
|
288 |
+
|
289 |
+
Notes:
|
290 |
+
- `DOCUMENT_TYPE`: Union[List[str], List[List[str]]]
|
291 |
+
- Uses helper function `is_list_of_strings_at_depth` to validate the format of lists of strings.
|
292 |
+
|
293 |
+
Example:
|
294 |
+
>>> tokenize_sentences = True
|
295 |
+
>>> predictions = ["This is prediction 1.", "This is prediction 2."]
|
296 |
+
>>> references = ["Reference for prediction 1.", "Reference for prediction 2."]
|
297 |
+
>>> documents = ["Document 1 content.", "Document 2 content."]
|
298 |
+
>>> _validate_input_format(tokenize_sentences, predictions, references, documents)
|
299 |
+
|
300 |
+
Example:
|
301 |
+
>>> tokenize_sentences = False
|
302 |
+
>>> predictions = [["Sentence 1 in prediction 1.", "Sentence 2 in prediction 1."],
|
303 |
+
>>> ["Sentence 1 in prediction 2.", "Sentence 2 in prediction 2."]]
|
304 |
+
>>> references = [["Sentences in reference 1."], ["Sentences in reference 2."]]
|
305 |
+
>>> documents = [["Sentence 1 in document 1.", "Sentence 2 in document 1."],
|
306 |
+
>>> ["Sentence 1 in document 2.", "Sentence 2 in document 2."]]
|
307 |
+
>>> _validate_input_format(tokenize_sentences, predictions, references, documents)
|
308 |
+
"""
|
309 |
+
if not (len(predictions) == len(references) == len(documents)):
|
310 |
+
raise ValueError("Predictions, References and Documents must have the same length.")
|
311 |
+
|
312 |
+
if len(predictions) == 0:
|
313 |
+
raise ValueError("Can't have empty inputs")
|
314 |
+
|
315 |
+
def is_list_of_strings_at_depth(lst_obj, depth: int):
|
316 |
+
return is_nested_list_of_type(lst_obj, element_type=str, depth=depth)
|
317 |
+
|
318 |
+
if tokenize_sentences:
|
319 |
+
condition = (
|
320 |
+
is_list_of_strings_at_depth(predictions, 1) and
|
321 |
+
is_list_of_strings_at_depth(references, 1) and
|
322 |
+
is_list_of_strings_at_depth(documents, 1)
|
323 |
+
)
|
324 |
+
else:
|
325 |
+
condition = (
|
326 |
+
is_list_of_strings_at_depth(predictions, 2) and
|
327 |
+
is_list_of_strings_at_depth(references, 2) and
|
328 |
+
is_list_of_strings_at_depth(documents, 2)
|
329 |
+
)
|
330 |
+
|
331 |
+
if not condition:
|
332 |
+
raise ValueError("Predictions, References and Documents are not valid input format. Refer to documentation.")
|
333 |
|
334 |
|
335 |
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
336 |
class SemnCG(evaluate.Metric):
|
337 |
+
"""
|
338 |
+
SemnCG (Semantic Normalized Cumulative Gain) Metric.
|
339 |
+
|
340 |
+
This metric evaluates the quality of predicted sentences in relation to reference sentences and documents
|
341 |
+
using Semantic Normalized Cumulative Gain (NCG). It computes the gain values and NCG scores based on
|
342 |
+
cosine similarity between sentence embeddings, leveraging a Sentence-BERT encoder.
|
343 |
+
"""
|
344 |
+
|
345 |
+
def __init__(self, model_name: str = "all-MiniLM-L6-v2", **kwargs):
|
346 |
+
self.sbert_encoder = get_sbert_encoder(model_name)
|
347 |
+
super().__init__(**kwargs)
|
348 |
|
349 |
def _info(self):
|
350 |
# TODO: Specifies the evaluate.EvaluationModuleInfo object
|
|
|
355 |
citation=_CITATION,
|
356 |
inputs_description=_KWARGS_DESCRIPTION,
|
357 |
# This defines the format of each prediction and reference
|
358 |
+
features=[
|
359 |
+
# Tokenize_Sentences = True
|
360 |
+
datasets.Features(
|
361 |
+
{
|
362 |
+
"predictions": datasets.Value("string"),
|
363 |
+
"references": datasets.Value("string"),
|
364 |
+
"documents": datasets.Value("string"),
|
365 |
+
}
|
366 |
+
),
|
367 |
+
# Tokenize_Sentences = False
|
368 |
+
datasets.Features(
|
369 |
+
{
|
370 |
+
"predictions": datasets.Sequence(datasets.Value("string", id="sequence"), id="predictions"),
|
371 |
+
"references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"),
|
372 |
+
"documents": datasets.Sequence(datasets.Value("string", id="sequence"), id="documents"),
|
373 |
+
}
|
374 |
+
),
|
375 |
+
],
|
376 |
+
# # Homepage of the module for documentation
|
377 |
+
# homepage="http://module.homepage",
|
378 |
+
# # Additional links to the codebase or references
|
379 |
+
# codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
|
380 |
+
reference_urls=["https://aclanthology.org/2022.findings-acl.122/"]
|
381 |
)
|
382 |
|
383 |
def _download_and_prepare(self, dl_manager):
|
384 |
"""Optional: download external resources useful to compute the scores"""
|
385 |
+
nltk.download("punkt", quiet=True)
|
386 |
+
|
387 |
+
def _compute(
|
388 |
+
self,
|
389 |
+
predictions: DOCUMENT_TYPE,
|
390 |
+
references: DOCUMENT_TYPE,
|
391 |
+
documents: DOCUMENT_TYPE,
|
392 |
+
k: int = 3,
|
393 |
+
gpu: DEVICE_TYPE = False,
|
394 |
+
verbose: bool = False,
|
395 |
+
batch_size: int = 32,
|
396 |
+
tokenize_sentences: bool = True,
|
397 |
+
pre_compute_embeddings: bool = False,
|
398 |
+
debug: bool = False,
|
399 |
+
) -> Union[Tuple[float, List[float]], Tuple[float, List[RankedGains]]]:
|
400 |
+
"""
|
401 |
+
Compute the Semantic Normalized Cumulative Gain (SemnCG) score.
|
402 |
+
|
403 |
+
Args:
|
404 |
+
predictions (DOCUMENT_TYPE): The predicted sentences.
|
405 |
+
`tokenize_sentences`=True -> predictions: List[str]
|
406 |
+
`tokenize_sentences`=False -> predictions: List[List[str]]
|
407 |
+
references (DOCUMENT_TYPE): The reference sentences.
|
408 |
+
`tokenize_sentences`=True -> references: List[str]
|
409 |
+
`tokenize_sentences`=False -> references: List[List[str]]
|
410 |
+
documents (DOCUMENT_TYPE): Input documents.
|
411 |
+
`tokenize_sentences`=True -> references: List[str]
|
412 |
+
`tokenize_sentences`=False -> references: List[List[str]]
|
413 |
+
k (int, optional): The rank threshold used for evaluating gains (typically top-k sentences). Default is 3.
|
414 |
+
gpu (DEVICE_TYPE, optional): Whether to use GPU for computation. Default is False.
|
415 |
+
verbose (bool, optional): Whether to print verbose logs. Default is False.
|
416 |
+
batch_size (int, optional): The batch size for encoding sentences. Default is 32.
|
417 |
+
tokenize_sentences (bool, optional): Whether to tokenize sentences. If True, sentences are tokenized before
|
418 |
+
processing. Default is True.
|
419 |
+
pre_compute_embeddings (bool, optional): Whether to pre-compute embeddings for all sentences. This speeds up
|
420 |
+
computation but requires more memory. Default is False.
|
421 |
+
debug (bool, optional): Whether to return detailed debug information including ranked gains. Default=False.
|
422 |
+
|
423 |
+
Returns:
|
424 |
+
Union[Tuple[float, List[float]], Tuple[float, List[RankedGains]]]:
|
425 |
+
If `debug` is False, returns a tuple containing the mean SemnCG score and a list of SemnCG scores for each document.
|
426 |
+
If `debug` is True, returns a tuple containing the mean SemnCG score and a list of `RankedGains` objects with detailed gain information for each document.
|
427 |
+
|
428 |
+
Raises:
|
429 |
+
ValueError: If the format of predictions, references, or documents does not meet the specified criteria.
|
430 |
+
|
431 |
+
Notes:
|
432 |
+
- Validates the format of predictions, references, and documents based on `tokenize_sentences`.
|
433 |
+
- Computes embeddings using a Sentence-BERT encoder.
|
434 |
+
- Computes cosine similarity between document, reference, and prediction embeddings.
|
435 |
+
- Calculates gain values and Normalized Cumulative Gain (NCG) scores.
|
436 |
+
- Optionally returns detailed debug information for each document if `debug` is True.
|
437 |
+
"""
|
438 |
+
|
439 |
+
# Validate inputs corresponding to flags
|
440 |
+
_validate_input_format(tokenize_sentences, predictions, references, documents)
|
441 |
+
|
442 |
+
# Get GPU
|
443 |
+
device = get_gpu(gpu)
|
444 |
+
if verbose:
|
445 |
+
print(f"Using devices: {device}")
|
446 |
+
|
447 |
+
# Get model
|
448 |
+
encoder = get_encoder(self.sbert_encoder, device=device, batch_size=batch_size, verbose=verbose)
|
449 |
+
|
450 |
+
if pre_compute_embeddings: # fast but takes more memory
|
451 |
+
predictions = [tokenize_and_prep_document(pred, tokenize_sentences) for pred in predictions]
|
452 |
+
references = [tokenize_and_prep_document(ref, tokenize_sentences) for ref in references]
|
453 |
+
documents = [tokenize_and_prep_document(doc, tokenize_sentences) for doc in documents]
|
454 |
+
|
455 |
+
# This is only done for debug case
|
456 |
+
sent_tokenized_documents = documents
|
457 |
+
|
458 |
+
# Compute All Embeddings
|
459 |
+
all_sentences = flatten_list(documents) + flatten_list(references) + flatten_list(predictions)
|
460 |
+
embeddings = encoder.encode(all_sentences)
|
461 |
+
|
462 |
+
prediction_sentences_count = [len(pred) for pred in predictions]
|
463 |
+
reference_sentences_count = [len(ref) for ref in references]
|
464 |
+
document_sentences_count = [len(doc) for doc in documents]
|
465 |
+
|
466 |
+
# Get embeddings corresponding to documents, references and predictions (IN ORDER)
|
467 |
+
doc_embeddings = slice_embeddings(embeddings, document_sentences_count)
|
468 |
+
ref_embeddings = slice_embeddings(embeddings[sum(document_sentences_count):], reference_sentences_count)
|
469 |
+
pred_embeddings = slice_embeddings(
|
470 |
+
embeddings[sum(document_sentences_count+reference_sentences_count):], prediction_sentences_count
|
471 |
+
)
|
472 |
+
|
473 |
+
iterable_obj = zip(pred_embeddings, ref_embeddings, doc_embeddings)
|
474 |
+
|
475 |
+
else:
|
476 |
+
iterable_obj = zip(predictions, references, documents)
|
477 |
+
|
478 |
+
out = []
|
479 |
+
for idx, (pred, ref, doc) in enumerate(tqdm(iterable_obj)):
|
480 |
+
|
481 |
+
if not pre_compute_embeddings: # Compute embeddings
|
482 |
+
ref_sentences = tokenize_and_prep_document(ref, tokenize_sentences)
|
483 |
+
pred_sentences = tokenize_and_prep_document(pred, tokenize_sentences)
|
484 |
+
doc_sentences = tokenize_and_prep_document(doc, tokenize_sentences)
|
485 |
+
|
486 |
+
# Compute Embeddings
|
487 |
+
doc_sentence_count = len(doc_sentences)
|
488 |
+
ref_sentence_count = len(ref_sentences)
|
489 |
+
all_sentences = doc_sentences + ref_sentences + pred_sentences
|
490 |
+
embeddings = encoder.encode(all_sentences)
|
491 |
+
doc_embeddings = embeddings[:doc_sentence_count]
|
492 |
+
ref_embeddings = embeddings[doc_sentence_count:doc_sentence_count + ref_sentence_count]
|
493 |
+
pred_embeddings = embeddings[doc_sentence_count + ref_sentence_count:]
|
494 |
+
else: # we already have embeddings
|
495 |
+
doc_embeddings = doc
|
496 |
+
ref_embeddings = ref
|
497 |
+
pred_embeddings = pred
|
498 |
+
|
499 |
+
doc_sentences = sent_tokenized_documents[idx]
|
500 |
+
|
501 |
+
# Compute Pair-Wise Cosine Similarity
|
502 |
+
ref_sim_scores = compute_cosine_similarity(doc_embeddings, ref_embeddings)
|
503 |
+
pred_sim_scores = compute_cosine_similarity(doc_embeddings, pred_embeddings)
|
504 |
+
|
505 |
+
# Compute Gains
|
506 |
+
ground_truth_gain = compute_gain(ref_sim_scores)
|
507 |
+
|
508 |
+
# this is used to compute top-predicted sentence indices
|
509 |
+
pred_gain = compute_gain(pred_sim_scores)
|
510 |
+
real_k = min(len(pred_gain), k)
|
511 |
+
|
512 |
+
# Compute NCG Scores
|
513 |
+
ncg_score = compute_ncg(pred_gain, ground_truth_gain, real_k)
|
514 |
+
|
515 |
+
if debug:
|
516 |
+
ground_truth_gain = [(doc_sentences[sent_idx], gain_val) for sent_idx, gain_val in ground_truth_gain]
|
517 |
+
pred_gain = [(doc_sentences[sent_idx], gain_val) for sent_idx, gain_val in pred_gain]
|
518 |
+
out.append(RankedGains(ground_truth_gain, pred_gain, k=real_k, ncg=ncg_score))
|
519 |
+
else:
|
520 |
+
out.append(ncg_score)
|
521 |
+
|
522 |
+
if debug:
|
523 |
+
return statistics.mean([ele.ncg for ele in out]), out
|
524 |
+
|
525 |
+
return statistics.mean(out), out
|
tests.py
CHANGED
@@ -1,17 +1,418 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import statistics
|
2 |
+
import unittest
|
3 |
+
from unittest.mock import patch, MagicMock
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from numpy.testing import assert_almost_equal
|
8 |
+
from sentence_transformers import SentenceTransformer
|
9 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
10 |
+
|
11 |
+
from .encoder_models import SBertEncoder, get_encoder, get_sbert_encoder
|
12 |
+
from .semncg import RankedGains, compute_cosine_similarity, compute_gain, score_ncg, compute_ncg, _validate_input_format, SemnCG
|
13 |
+
from .utils import get_gpu, slice_embeddings, is_nested_list_of_type, flatten_list, prep_sentences, tokenize_and_prep_document
|
14 |
+
|
15 |
+
|
16 |
+
class TestUtils(unittest.TestCase):
|
17 |
+
def test_get_gpu(self):
|
18 |
+
gpu_count = torch.cuda.device_count()
|
19 |
+
gpu_available = torch.cuda.is_available()
|
20 |
+
|
21 |
+
# Test single boolean input
|
22 |
+
self.assertEqual(get_gpu(True), 0 if gpu_available else "cpu")
|
23 |
+
self.assertEqual(get_gpu(False), "cpu")
|
24 |
+
|
25 |
+
# Test single string input
|
26 |
+
self.assertEqual(get_gpu("cpu"), "cpu")
|
27 |
+
self.assertEqual(get_gpu("gpu"), 0 if gpu_available else "cpu")
|
28 |
+
self.assertEqual(get_gpu("cuda"), 0 if gpu_available else "cpu")
|
29 |
+
|
30 |
+
# Test single integer input
|
31 |
+
self.assertEqual(get_gpu(0), 0 if gpu_available else "cpu")
|
32 |
+
self.assertEqual(get_gpu(1), 1 if gpu_available else "cpu")
|
33 |
+
|
34 |
+
# Test list input with unique elements
|
35 |
+
self.assertEqual(get_gpu([True, "cpu", 0]), [0, "cpu"] if gpu_available else ["cpu", "cpu", "cpu"])
|
36 |
+
|
37 |
+
# Test list input with duplicate elements
|
38 |
+
self.assertEqual(get_gpu([0, 0, "gpu"]), 0 if gpu_available else ["cpu", "cpu", "cpu"])
|
39 |
+
|
40 |
+
# Test list input with duplicate elements of different types
|
41 |
+
self.assertEqual(get_gpu([True, 0, "gpu"]), 0 if gpu_available else ["cpu", "cpu", "cpu"])
|
42 |
+
|
43 |
+
# Test list input but only one element
|
44 |
+
self.assertEqual(get_gpu([True]), 0 if gpu_available else "cpu")
|
45 |
+
|
46 |
+
# Test list input with all integers
|
47 |
+
self.assertEqual(get_gpu(list(range(gpu_count))),
|
48 |
+
list(range(gpu_count)) if gpu_available else gpu_count * ["cpu"])
|
49 |
+
|
50 |
+
with self.assertRaises(ValueError):
|
51 |
+
get_gpu("invalid")
|
52 |
+
|
53 |
+
with self.assertRaises(ValueError):
|
54 |
+
get_gpu(torch.cuda.device_count())
|
55 |
+
|
56 |
+
def test_prep_sentences(self):
|
57 |
+
# Test normal case
|
58 |
+
self.assertEqual(prep_sentences(["Hello, world!", " This is a test. ", "!!!"]),
|
59 |
+
['Hello, world!', 'This is a test.'])
|
60 |
+
|
61 |
+
# Test case with only punctuations
|
62 |
+
with self.assertRaises(ValueError):
|
63 |
+
prep_sentences(["!!!", "..."])
|
64 |
+
|
65 |
+
# Test case with empty list
|
66 |
+
with self.assertRaises(ValueError):
|
67 |
+
prep_sentences([])
|
68 |
+
|
69 |
+
def test_tokenize_and_prep_document(self):
|
70 |
+
# Test tokenize=True with string input
|
71 |
+
self.assertEqual(tokenize_and_prep_document("Hello, world! This is a test.", True),
|
72 |
+
['Hello, world!', 'This is a test.'])
|
73 |
+
|
74 |
+
# Test tokenize=False with list of strings input
|
75 |
+
self.assertEqual(tokenize_and_prep_document(["Hello, world!", "This is a test."], False),
|
76 |
+
['Hello, world!', 'This is a test.'])
|
77 |
+
|
78 |
+
# Test tokenize=True with empty document
|
79 |
+
with self.assertRaises(ValueError):
|
80 |
+
tokenize_and_prep_document("!!! ...", True)
|
81 |
+
|
82 |
+
def test_slice_embeddings(self):
|
83 |
+
# Case 1
|
84 |
+
embeddings = np.random.rand(10, 5)
|
85 |
+
num_sentences = [3, 2, 5]
|
86 |
+
expected_output = [embeddings[:3], embeddings[3:5], embeddings[5:]]
|
87 |
+
self.assertTrue(
|
88 |
+
all(np.array_equal(a, b) for a, b in zip(slice_embeddings(embeddings, num_sentences),
|
89 |
+
expected_output))
|
90 |
+
)
|
91 |
+
|
92 |
+
# Case 2
|
93 |
+
num_sentences_nested = [[2, 1], [3, 4]]
|
94 |
+
expected_output_nested = [[embeddings[:2], embeddings[2:3]], [embeddings[3:6], embeddings[6:]]]
|
95 |
+
self.assertTrue(
|
96 |
+
slice_embeddings(embeddings, num_sentences_nested), expected_output_nested
|
97 |
+
)
|
98 |
+
|
99 |
+
# Case 3
|
100 |
+
document_sentences_count = [10, 8, 7]
|
101 |
+
reference_sentences_count = [5, 3, 2]
|
102 |
+
pred_sentences_count = [2, 2, 1]
|
103 |
+
all_embeddings = np.random.rand(
|
104 |
+
sum(document_sentences_count + reference_sentences_count + pred_sentences_count), 5,
|
105 |
+
)
|
106 |
+
|
107 |
+
embeddings = all_embeddings
|
108 |
+
expected_doc_embeddings = [embeddings[:10], embeddings[10:18], embeddings[18:25]]
|
109 |
+
|
110 |
+
embeddings = all_embeddings[25:]
|
111 |
+
expected_ref_embeddings = [embeddings[:5], embeddings[5:8], embeddings[8:10]]
|
112 |
+
|
113 |
+
embeddings = all_embeddings[35:]
|
114 |
+
expected_pred_embeddings = [embeddings[:2], embeddings[2:4], embeddings[4:5]]
|
115 |
+
|
116 |
+
doc_embeddings = slice_embeddings(all_embeddings, document_sentences_count)
|
117 |
+
ref_embeddings = slice_embeddings(all_embeddings[sum(document_sentences_count):], reference_sentences_count)
|
118 |
+
pred_embeddings = slice_embeddings(
|
119 |
+
all_embeddings[sum(document_sentences_count+reference_sentences_count):], pred_sentences_count
|
120 |
+
)
|
121 |
+
|
122 |
+
self.assertTrue(doc_embeddings, expected_doc_embeddings)
|
123 |
+
self.assertTrue(ref_embeddings, expected_ref_embeddings)
|
124 |
+
self.assertTrue(pred_embeddings, expected_pred_embeddings)
|
125 |
+
|
126 |
+
with self.assertRaises(TypeError):
|
127 |
+
slice_embeddings(embeddings, "invalid")
|
128 |
+
|
129 |
+
def test_is_nested_list_of_type(self):
|
130 |
+
# Test case: Depth 0, single element matching element_type
|
131 |
+
self.assertTrue(is_nested_list_of_type("test", str, 0))
|
132 |
+
|
133 |
+
# Test case: Depth 0, single element not matching element_type
|
134 |
+
self.assertFalse(is_nested_list_of_type("test", int, 0))
|
135 |
+
|
136 |
+
# Test case: Depth 1, list of elements matching element_type
|
137 |
+
self.assertTrue(is_nested_list_of_type(["apple", "banana"], str, 1))
|
138 |
+
|
139 |
+
# Test case: Depth 1, list of elements not matching element_type
|
140 |
+
self.assertFalse(is_nested_list_of_type([1, 2, 3], str, 1))
|
141 |
+
|
142 |
+
# Test case: Depth 0 (Wrong), list of elements matching element_type
|
143 |
+
self.assertFalse(is_nested_list_of_type([1, 2, 3], str, 0))
|
144 |
+
|
145 |
+
# Depth 2
|
146 |
+
self.assertTrue(is_nested_list_of_type([[1, 2], [3, 4]], int, 2))
|
147 |
+
self.assertTrue(is_nested_list_of_type([['1', '2'], ['3', '4']], str, 2))
|
148 |
+
self.assertFalse(is_nested_list_of_type([[1, 2], ["a", "b"]], int, 2))
|
149 |
+
|
150 |
+
# Depth 3
|
151 |
+
self.assertFalse(is_nested_list_of_type([[[1], [2]], [[3], [4]]], list, 3))
|
152 |
+
self.assertTrue(is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3))
|
153 |
+
|
154 |
+
with self.assertRaises(ValueError):
|
155 |
+
is_nested_list_of_type([1, 2], int, -1)
|
156 |
+
|
157 |
+
def test_flatten_list(self):
|
158 |
+
self.assertEqual(flatten_list([1, [2, 3], [[4], 5]]), [1, 2, 3, 4, 5])
|
159 |
+
self.assertEqual(flatten_list([]), [])
|
160 |
+
self.assertEqual(flatten_list([1, 2, 3]), [1, 2, 3])
|
161 |
+
self.assertEqual(flatten_list([[[[1]]]]), [1])
|
162 |
+
|
163 |
+
|
164 |
+
class TestSBertEncoder(unittest.TestCase):
|
165 |
+
|
166 |
+
def setUp(self) -> None:
|
167 |
+
# Set up a test SentenceTransformer model
|
168 |
+
self.model_name = "paraphrase-distilroberta-base-v1"
|
169 |
+
self.sbert_model = get_sbert_encoder(self.model_name)
|
170 |
+
self.device = "cpu" # For testing on CPU
|
171 |
+
self.batch_size = 32
|
172 |
+
self.verbose = False
|
173 |
+
self.encoder = SBertEncoder(self.sbert_model, self.device, self.batch_size, self.verbose)
|
174 |
+
|
175 |
+
def test_encode_single_sentence(self):
|
176 |
+
sentence = "Hello, world!"
|
177 |
+
embeddings = self.encoder.encode([sentence])
|
178 |
+
self.assertEqual(embeddings.shape, (1, 768)) # Adjust shape based on your model's embedding dimension
|
179 |
+
|
180 |
+
def test_encode_multiple_sentences(self):
|
181 |
+
sentences = ["Hello, world!", "This is a test."]
|
182 |
+
embeddings = self.encoder.encode(sentences)
|
183 |
+
self.assertEqual(embeddings.shape, (2, 768)) # Adjust shape based on your model's embedding dimension
|
184 |
+
|
185 |
+
def test_get_sbert_encoder(self):
|
186 |
+
model_name = "paraphrase-distilroberta-base-v1"
|
187 |
+
sbert_model = get_sbert_encoder(model_name)
|
188 |
+
self.assertIsInstance(sbert_model, SentenceTransformer)
|
189 |
+
|
190 |
+
def test_encode_with_gpu(self):
|
191 |
+
if torch.cuda.is_available():
|
192 |
+
device = "cuda"
|
193 |
+
encoder = get_encoder(self.sbert_model, device, self.batch_size, self.verbose)
|
194 |
+
sentences = ["Hello, world!", "This is a test."]
|
195 |
+
embeddings = encoder.encode(sentences)
|
196 |
+
self.assertEqual(embeddings.shape, (2, 768)) # Adjust shape based on your model's embedding dimension
|
197 |
+
else:
|
198 |
+
self.skipTest("CUDA not available, skipping GPU test.")
|
199 |
+
|
200 |
+
def test_encode_multi_device(self):
|
201 |
+
if torch.cuda.device_count() < 2:
|
202 |
+
self.skipTest("Multi-GPU test requires at least 2 GPUs.")
|
203 |
+
else:
|
204 |
+
devices = ["cuda:0", "cuda:1"]
|
205 |
+
encoder = get_encoder(self.sbert_model, devices, self.batch_size, self.verbose)
|
206 |
+
sentences = ["This is a test sentence.", "Here is another sentence.", "This is a test sentence."]
|
207 |
+
embeddings = encoder.encode(sentences)
|
208 |
+
self.assertIsInstance(embeddings, np.ndarray)
|
209 |
+
self.assertEqual(embeddings.shape[0], 3)
|
210 |
+
self.assertEqual(embeddings.shape[1], self.encoder.model.get_sentence_embedding_dimension())
|
211 |
+
|
212 |
+
|
213 |
+
class TestGetEncoder(unittest.TestCase):
|
214 |
+
def setUp(self):
|
215 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
216 |
+
self.batch_size = 8
|
217 |
+
self.verbose = False
|
218 |
+
|
219 |
+
def _base_test(self, model_name):
|
220 |
+
sbert_model = get_sbert_encoder(model_name)
|
221 |
+
encoder = get_encoder(sbert_model, self.device, self.batch_size, self.verbose)
|
222 |
+
|
223 |
+
# Assert
|
224 |
+
self.assertIsInstance(encoder, SBertEncoder)
|
225 |
+
self.assertEqual(encoder.device, self.device)
|
226 |
+
self.assertEqual(encoder.batch_size, self.batch_size)
|
227 |
+
self.assertEqual(encoder.verbose, self.verbose)
|
228 |
+
|
229 |
+
def test_get_sbert_encoder(self):
|
230 |
+
model_name = "stsb-roberta-large"
|
231 |
+
self._base_test(model_name)
|
232 |
+
|
233 |
+
def test_sbert_model(self):
|
234 |
+
model_name = "all-mpnet-base-v2"
|
235 |
+
self._base_test(model_name)
|
236 |
+
|
237 |
+
def test_huggingface_model(self):
|
238 |
+
"""Test Huggingface models which work with SBert library"""
|
239 |
+
model_name = "roberta-base"
|
240 |
+
self._base_test(model_name)
|
241 |
+
|
242 |
+
def test_get_encoder_environment_error(self): # This parameter is used when using patch decorator
|
243 |
+
model_name = "abc" # Wrong model_name
|
244 |
+
with self.assertRaises(EnvironmentError):
|
245 |
+
get_sbert_encoder(model_name)
|
246 |
+
|
247 |
+
def test_get_encoder_other_exception(self):
|
248 |
+
model_name = "apple/OpenELM-270M" # This model is not supported by SentenceTransformer lib
|
249 |
+
with self.assertRaises(RuntimeError):
|
250 |
+
get_sbert_encoder(model_name)
|
251 |
+
|
252 |
+
|
253 |
+
class TestRankedGainsDataclass(unittest.TestCase):
|
254 |
+
def test_ranked_gains_dataclass(self):
|
255 |
+
# Test initialization and attribute access
|
256 |
+
gt_gains = [("doc1", 0.8), ("doc2", 0.6)]
|
257 |
+
pred_gains = [("doc2", 0.7), ("doc1", 0.5)]
|
258 |
+
k = 2
|
259 |
+
ncg = 0.75
|
260 |
+
ranked_gains = RankedGains(gt_gains, pred_gains, k, ncg)
|
261 |
+
|
262 |
+
self.assertEqual(ranked_gains.gt_gains, gt_gains)
|
263 |
+
self.assertEqual(ranked_gains.pred_gains, pred_gains)
|
264 |
+
self.assertEqual(ranked_gains.k, k)
|
265 |
+
self.assertEqual(ranked_gains.ncg, ncg)
|
266 |
+
|
267 |
+
|
268 |
+
class TestComputeCosineSimilarity(unittest.TestCase):
|
269 |
+
def test_compute_cosine_similarity(self):
|
270 |
+
doc_embeds = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
|
271 |
+
ref_embeds = np.array([[0.2, 0.3, 0.4], [0.5, 0.6, 0.7]])
|
272 |
+
# Test compute_cosine_similarity function
|
273 |
+
similarity_scores = compute_cosine_similarity(doc_embeds, ref_embeds)
|
274 |
+
print(similarity_scores)
|
275 |
+
|
276 |
+
# Example values, change as per actual function output
|
277 |
+
expected_scores = [0.980, 0.997]
|
278 |
+
|
279 |
+
self.assertAlmostEqual(similarity_scores[0], expected_scores[0], places=3)
|
280 |
+
self.assertAlmostEqual(similarity_scores[1], expected_scores[1], places=3)
|
281 |
+
|
282 |
+
|
283 |
+
class TestComputeGain(unittest.TestCase):
|
284 |
+
def test_compute_gain(self):
|
285 |
+
# Test compute_gain function
|
286 |
+
sim_scores = [0.8, 0.6, 0.7]
|
287 |
+
gains = compute_gain(sim_scores)
|
288 |
+
print(gains)
|
289 |
+
|
290 |
+
# Example values, change as per actual function output
|
291 |
+
expected_gains = [(0, 0.5), (2, 0.3333333333333333), (1, 0.16666666666666666)]
|
292 |
+
|
293 |
+
self.assertEqual(gains, expected_gains)
|
294 |
+
|
295 |
+
|
296 |
+
class TestScoreNcg(unittest.TestCase):
|
297 |
+
def test_score_ncg(self):
|
298 |
+
# Test score_ncg function
|
299 |
+
model_relevance = [0.8, 0.7, 0.6]
|
300 |
+
gt_relevance = [1.0, 0.9, 0.8]
|
301 |
+
ncg_score = score_ncg(model_relevance, gt_relevance)
|
302 |
+
expected_ncg = 0.778 # Example value, change as per actual function output
|
303 |
+
|
304 |
+
self.assertAlmostEqual(ncg_score, expected_ncg, places=3)
|
305 |
+
|
306 |
+
|
307 |
+
class TestComputeNcg(unittest.TestCase):
|
308 |
+
def test_compute_ncg(self):
|
309 |
+
# Test compute_ncg function
|
310 |
+
pred_gains = [(0, 0.8), (2, 0.7), (1, 0.6)]
|
311 |
+
gt_gains = [(0, 1.0), (1, 0.9), (2, 0.8)]
|
312 |
+
k = 3
|
313 |
+
ncg_score = compute_ncg(pred_gains, gt_gains, k)
|
314 |
+
expected_ncg = 1.0 # TODO: Confirm this with Dr. Santu
|
315 |
+
|
316 |
+
self.assertAlmostEqual(ncg_score, expected_ncg, places=6)
|
317 |
+
|
318 |
+
|
319 |
+
class TestValidateInputFormat(unittest.TestCase):
|
320 |
+
def test_validate_input_format(self):
|
321 |
+
# Test _validate_input_format function
|
322 |
+
tokenize_sentences = True
|
323 |
+
predictions = ["Prediction 1", "Prediction 2"]
|
324 |
+
references = ["Reference 1", "Reference 2"]
|
325 |
+
documents = ["Document 1", "Document 2"]
|
326 |
+
|
327 |
+
# No exception should be raised for valid input
|
328 |
+
try:
|
329 |
+
_validate_input_format(tokenize_sentences, predictions, references, documents)
|
330 |
+
except ValueError as e:
|
331 |
+
self.fail(f"_validate_input_format raised ValueError unexpectedly: {str(e)}")
|
332 |
+
|
333 |
+
# Test invalid input format
|
334 |
+
predictions_invalid = [["Sentence 1 in prediction 1.", "Sentence 2 in prediction 1."],
|
335 |
+
["Sentence 1 in prediction 2.", "Sentence 2 in prediction 2."]]
|
336 |
+
references_invalid = [["Sentences in reference 1."], ["Sentences in reference 2."]]
|
337 |
+
documents_invalid = [["Sentence 1 in document 1.", "Sentence 2 in document 1."],
|
338 |
+
["Sentence 1 in document 2.", "Sentence 2 in document 2."]]
|
339 |
+
|
340 |
+
with self.assertRaises(ValueError):
|
341 |
+
_validate_input_format(tokenize_sentences, predictions_invalid, references, documents)
|
342 |
+
|
343 |
+
with self.assertRaises(ValueError):
|
344 |
+
_validate_input_format(tokenize_sentences, predictions, references_invalid, documents)
|
345 |
+
|
346 |
+
with self.assertRaises(ValueError):
|
347 |
+
_validate_input_format(tokenize_sentences, predictions, references, documents_invalid)
|
348 |
+
|
349 |
+
|
350 |
+
class TestSemnCG(unittest.TestCase):
|
351 |
+
def setUp(self):
|
352 |
+
self.model_name = "stsb-distilbert-base"
|
353 |
+
self.metric = SemnCG(self.model_name)
|
354 |
+
|
355 |
+
def _basic_assertion(self, result, debug: bool = False):
|
356 |
+
self.assertIsInstance(result, tuple)
|
357 |
+
self.assertEqual(len(result), 2)
|
358 |
+
self.assertIsInstance(result[0], float)
|
359 |
+
self.assertTrue(0.0 <= result[0] <= 1.0)
|
360 |
+
self.assertIsInstance(result[1], list)
|
361 |
+
if debug:
|
362 |
+
for ranked_gain in result[1]:
|
363 |
+
self.assertTrue(isinstance(ranked_gain, RankedGains))
|
364 |
+
self.assertTrue(0.0 <= ranked_gain.ncg <= 1.0)
|
365 |
+
else:
|
366 |
+
for gain in result[1]:
|
367 |
+
self.assertTrue(isinstance(gain, float))
|
368 |
+
self.assertTrue(0.0 <= gain <= 1.0)
|
369 |
+
|
370 |
+
def test_compute_basic(self):
|
371 |
+
predictions = ["The cat sat on the mat.", "The quick brown fox jumps over the lazy dog."]
|
372 |
+
references = ["A cat was sitting on a mat.", "A quick brown fox jumped over a lazy dog."]
|
373 |
+
documents = ["There was a cat on a mat.", "The quick brown fox jumped over the lazy dog."]
|
374 |
+
|
375 |
+
result = self.metric.compute(predictions=predictions, references=references, documents=documents)
|
376 |
+
self._basic_assertion(result)
|
377 |
+
|
378 |
+
def test_compute_with_tokenization(self):
|
379 |
+
predictions = [["The cat sat on the mat."], ["The quick brown fox jumps over the lazy dog."]]
|
380 |
+
references = [["A cat was sitting on a mat."], ["A quick brown fox jumped over a lazy dog."]]
|
381 |
+
documents = [["There was a cat on a mat."], ["The quick brown fox jumped over the lazy dog."]]
|
382 |
+
|
383 |
+
result = self.metric.compute(
|
384 |
+
predictions=predictions, references=references, documents=documents, tokenize_sentences=False
|
385 |
+
)
|
386 |
+
self._basic_assertion(result)
|
387 |
+
|
388 |
+
def test_compute_with_pre_compute_embeddings(self):
|
389 |
+
predictions = ["The cat sat on the mat.", "The quick brown fox jumps over the lazy dog."]
|
390 |
+
references = ["A cat was sitting on a mat.", "A quick brown fox jumped over a lazy dog."]
|
391 |
+
documents = ["There was a cat on a mat.", "The quick brown fox jumped over the lazy dog."]
|
392 |
+
|
393 |
+
result = self.metric.compute(
|
394 |
+
predictions=predictions, references=references, documents=documents, pre_compute_embeddings=True
|
395 |
+
)
|
396 |
+
self._basic_assertion(result)
|
397 |
+
|
398 |
+
def test_compute_with_debug(self):
|
399 |
+
predictions = ["The cat sat on the mat.", "The quick brown fox jumps over the lazy dog."]
|
400 |
+
references = ["A cat was sitting on a mat.", "A quick brown fox jumped over a lazy dog."]
|
401 |
+
documents = ["There was a cat on a mat.", "The quick brown fox jumped over the lazy dog."]
|
402 |
+
|
403 |
+
result = self.metric.compute(
|
404 |
+
predictions=predictions, references=references, documents=documents, debug=True
|
405 |
+
)
|
406 |
+
self._basic_assertion(result, debug=True)
|
407 |
+
|
408 |
+
def test_compute_invalid_input_format(self):
|
409 |
+
predictions = "The cat sat on the mat."
|
410 |
+
references = ["A cat was sitting on a mat."]
|
411 |
+
documents = ["There was a cat on a mat."]
|
412 |
+
|
413 |
+
with self.assertRaises(ValueError):
|
414 |
+
self.metric.compute(predictions=predictions, references=references, documents=documents)
|
415 |
+
|
416 |
+
|
417 |
+
if __name__ == '__main__':
|
418 |
+
unittest.main(verbosity=2)
|
type_aliases.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from typing import List, Union, Tuple
|
3 |
+
|
4 |
+
from numpy.typing import NDArray
|
5 |
+
|
6 |
+
|
7 |
+
NumSentencesType = Union[List[int], List[List[int]]]
|
8 |
+
EmbeddingSlicesType = Union[List[NDArray], List[List[NDArray]]]
|
9 |
+
DEVICE_TYPE = Union[bool, str, int, List[Union[str, int]]]
|
10 |
+
ENCODER_DEVICE_TYPE = Union[str, int, List[Union[str, int]]]
|
11 |
+
DOCUMENT_TYPE = Union[List[str], List[List[str]]]
|
utils.py
ADDED
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import string
|
3 |
+
from typing import List, Tuple, Union
|
4 |
+
|
5 |
+
import nltk
|
6 |
+
import numpy as np
|
7 |
+
from numpy.typing import NDArray
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from .type_aliases import DEVICE_TYPE, ENCODER_DEVICE_TYPE, NumSentencesType, EmbeddingSlicesType
|
11 |
+
|
12 |
+
|
13 |
+
def get_gpu(gpu: DEVICE_TYPE) -> ENCODER_DEVICE_TYPE:
|
14 |
+
"""
|
15 |
+
Determine the correct GPU device based on the provided input. In the following, output 0 means CUDA device 0.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
gpu (Union[bool, str, int, List[Union[str, int]]]): Input specifying the GPU device(s):
|
19 |
+
- bool: If True, returns 0 if CUDA is available, otherwise returns "cpu".
|
20 |
+
- str: Can be "cpu", "gpu", or "cuda" (case-insensitive). Returns 0 if CUDA is available
|
21 |
+
and the input is not "cpu", otherwise returns "cpu".
|
22 |
+
- int: Should be a valid GPU index. Returns the index if CUDA is available and valid,
|
23 |
+
otherwise returns "cpu".
|
24 |
+
- List[Union[str, int]]: List containing combinations of the str/int. Processes each
|
25 |
+
element and returns a list of corresponding results.
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
Union[str, int, List[Union[str, int]]]: Depending on the input type:
|
29 |
+
- str: Returns "cpu" if no GPU is available or the input is "cpu".
|
30 |
+
- int: Returns the GPU index if valid and CUDA is available.
|
31 |
+
- List[Union[str, int]]: Returns a list of strings and/or integers based on the input list.
|
32 |
+
|
33 |
+
Raises:
|
34 |
+
ValueError: If the input gpu type is not recognized or invalid.
|
35 |
+
ValueError: If a string input is not one of ["cpu", "gpu", "cuda"].
|
36 |
+
ValueError: If an integer input is outside the valid range of GPU indices.
|
37 |
+
|
38 |
+
Notes:
|
39 |
+
- This function checks CUDA availability using torch.cuda.is_available() and counts
|
40 |
+
available GPUs using torch.cuda.device_count().
|
41 |
+
- Case insensitivity is maintained for string inputs ("cpu", "gpu", "cuda").
|
42 |
+
- The function ensures robust error handling for invalid input types or out-of-range indices.
|
43 |
+
"""
|
44 |
+
|
45 |
+
# Ensure gpu index is within the range of total available gpus
|
46 |
+
gpu_available = torch.cuda.is_available()
|
47 |
+
gpu_count = torch.cuda.device_count()
|
48 |
+
correct_strs = ["cpu", "gpu", "cuda"]
|
49 |
+
|
50 |
+
def _get_single_device(gpu_item):
|
51 |
+
if isinstance(gpu_item, bool):
|
52 |
+
return 0 if gpu_item and gpu_available else "cpu"
|
53 |
+
elif isinstance(gpu_item, str):
|
54 |
+
if gpu_item.lower() not in correct_strs:
|
55 |
+
raise ValueError(f"Wrong gpu type: {gpu_item}. Should be one of {correct_strs}")
|
56 |
+
return 0 if (gpu_item.lower() != "cpu") and gpu_available else "cpu"
|
57 |
+
elif isinstance(gpu_item, int):
|
58 |
+
if gpu_item >= gpu_count:
|
59 |
+
raise ValueError(
|
60 |
+
f"There are {gpu_count} GPUs available. Provide a valid GPU index. You provided: {gpu_item}"
|
61 |
+
)
|
62 |
+
return gpu_item if gpu_available else "cpu"
|
63 |
+
else:
|
64 |
+
raise ValueError(f"Invalid gpu type: {type(gpu_item)}. Must be bool, str, or int.")
|
65 |
+
|
66 |
+
if isinstance(gpu, list):
|
67 |
+
seen_indices = set()
|
68 |
+
result = []
|
69 |
+
for item in gpu:
|
70 |
+
device = _get_single_device(item)
|
71 |
+
if isinstance(device, int):
|
72 |
+
if device not in seen_indices:
|
73 |
+
seen_indices.add(device)
|
74 |
+
result.append(device)
|
75 |
+
else:
|
76 |
+
result.append(device)
|
77 |
+
return result[0] if len(result) == 1 else result
|
78 |
+
else:
|
79 |
+
return _get_single_device(gpu)
|
80 |
+
|
81 |
+
|
82 |
+
def prep_sentences(sentences: List[str]) -> List[str]:
|
83 |
+
"""
|
84 |
+
Processes a list of sentences by stripping whitespace (at beginning and the end),
|
85 |
+
, filtering out empty sentences or sentences that only contains punctuations.
|
86 |
+
|
87 |
+
Args:
|
88 |
+
sentences (List[str]): A list of sentences to be processed.
|
89 |
+
|
90 |
+
Returns:
|
91 |
+
List[str]: A list of cleaned sentences
|
92 |
+
|
93 |
+
Raises:
|
94 |
+
ValueError: If the resulting list of sentences is empty.
|
95 |
+
|
96 |
+
Example:
|
97 |
+
>>> prep_sentences(["Hello, world!", " This is a test. ", "!!!"])
|
98 |
+
['Hello, world!', 'This is a test.']
|
99 |
+
|
100 |
+
>>> prep_sentences(["!!!", "..."])
|
101 |
+
ValueError: Document can't be empty.
|
102 |
+
"""
|
103 |
+
out = []
|
104 |
+
for sent in sentences:
|
105 |
+
sent = sent.strip()
|
106 |
+
sent_wo_punctuation = (
|
107 |
+
sent.translate(str.maketrans("", "", string.punctuation))
|
108 |
+
).strip()
|
109 |
+
if sent_wo_punctuation:
|
110 |
+
out.append(sent)
|
111 |
+
|
112 |
+
if len(out) == 0:
|
113 |
+
raise ValueError("Document can't be empty.")
|
114 |
+
return out
|
115 |
+
|
116 |
+
|
117 |
+
def tokenize_and_prep_document(document: Union[str, List[str]], tokenize: bool) -> List[str]:
|
118 |
+
"""
|
119 |
+
Tokenizes and prepares a document by either tokenizing it into sentences and processing each sentence,
|
120 |
+
or directly processing each element if `tokenize` is False.
|
121 |
+
|
122 |
+
Args:
|
123 |
+
document (Union[str, List[str]]): The document to be processed. It can be a single string (enitre document) or a
|
124 |
+
list of strings (list of sentences).
|
125 |
+
tokenize (bool): If True, tokenizes `document` into sentences using NLTK's sentence tokenizer before processing.
|
126 |
+
If False, processes each element of `document` directly as sentences.
|
127 |
+
|
128 |
+
Returns:
|
129 |
+
List[str]: A list of cleaned sentences.
|
130 |
+
|
131 |
+
Raises:
|
132 |
+
ValueError: If the resulting list of sentences is empty after processing.
|
133 |
+
|
134 |
+
Example:
|
135 |
+
>>> tokenize_and_prep_document("Hello, world! This is a test.", True)
|
136 |
+
['Hello, world!', 'This is a test.']
|
137 |
+
|
138 |
+
>>> tokenize_and_prep_document(["Hello, world!", "This is a test."], False)
|
139 |
+
['Hello, world!', 'This is a test.']
|
140 |
+
|
141 |
+
>>> tokenize_and_prep_document("!!! ...", True)
|
142 |
+
ValueError: Document can't be empty.
|
143 |
+
|
144 |
+
Note: Only the following two cases are possible.
|
145 |
+
tokenizer=True -> document: str
|
146 |
+
tokenizer=False -> document: List[str].
|
147 |
+
"""
|
148 |
+
if tokenize:
|
149 |
+
return prep_sentences(nltk.tokenize.sent_tokenize(document))
|
150 |
+
return prep_sentences(document)
|
151 |
+
|
152 |
+
|
153 |
+
def flatten_list(nested_list: list) -> list:
|
154 |
+
"""
|
155 |
+
Recursively flattens a nested list of any depth.
|
156 |
+
|
157 |
+
Parameters:
|
158 |
+
nested_list (list): The nested list to flatten.
|
159 |
+
|
160 |
+
Returns:
|
161 |
+
list: A flat list containing all the elements of the nested list.
|
162 |
+
"""
|
163 |
+
flat_list = []
|
164 |
+
for item in nested_list:
|
165 |
+
if isinstance(item, list):
|
166 |
+
flat_list.extend(flatten_list(item))
|
167 |
+
else:
|
168 |
+
flat_list.append(item)
|
169 |
+
return flat_list
|
170 |
+
|
171 |
+
|
172 |
+
def is_nested_list_of_type(lst_obj, element_type, depth: int) -> bool:
|
173 |
+
"""
|
174 |
+
Check if the given object is a nested list of a specific type up to a specified depth.
|
175 |
+
|
176 |
+
Args:
|
177 |
+
- lst_obj: The object to check, expected to be a list or a single element.
|
178 |
+
- element_type: The type that each element in the nested list should match.
|
179 |
+
- depth (int): The depth of nesting to check. Must be non-negative.
|
180 |
+
|
181 |
+
Returns:
|
182 |
+
- bool: True if lst_obj is a nested list of the specified type up to the given depth, False otherwise.
|
183 |
+
|
184 |
+
Raises:
|
185 |
+
- ValueError: If depth is negative.
|
186 |
+
|
187 |
+
Example:
|
188 |
+
```python
|
189 |
+
# Test cases
|
190 |
+
is_nested_list_of_type("test", str, 0) # Returns True
|
191 |
+
is_nested_list_of_type([1, 2, 3], str, 0) # Returns False
|
192 |
+
is_nested_list_of_type(["apple", "banana"], str, 1) # Returns True
|
193 |
+
is_nested_list_of_type([[1, 2], [3, 4]], int, 2) # Returns True
|
194 |
+
is_nested_list_of_type([[1, 2], ["a", "b"]], int, 2) # Returns False
|
195 |
+
is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3) # Returns True
|
196 |
+
```
|
197 |
+
|
198 |
+
Explanation:
|
199 |
+
- The function checks if `lst_obj` is a nested list of elements of type `element_type` up to `depth` levels deep.
|
200 |
+
- If `depth` is 0, it checks if `lst_obj` itself is of type `element_type`.
|
201 |
+
- If `depth` is greater than 0, it recursively checks each level of nesting to ensure all elements match `element_type`.
|
202 |
+
- Raises a `ValueError` if `depth` is negative, as depth must be a non-negative integer.
|
203 |
+
"""
|
204 |
+
if depth == 0:
|
205 |
+
return isinstance(lst_obj, element_type)
|
206 |
+
elif depth > 0:
|
207 |
+
return isinstance(lst_obj, list) and all(is_nested_list_of_type(item, element_type, depth - 1) for item in lst_obj)
|
208 |
+
else:
|
209 |
+
raise ValueError("Depth can't be negative")
|
210 |
+
|
211 |
+
|
212 |
+
def slice_embeddings(embeddings: NDArray, num_sentences: NumSentencesType) -> EmbeddingSlicesType:
|
213 |
+
"""
|
214 |
+
Slice embeddings into segments based on the provided number of sentences per segment.
|
215 |
+
|
216 |
+
Args:
|
217 |
+
- embeddings (np.ndarray): The array of embeddings to be sliced.
|
218 |
+
- num_sentences (Union[List[int], List[List[int]]]):
|
219 |
+
- If a list of integers: Specifies the number of embeddings to take in each slice.
|
220 |
+
- If a list of lists of integers: Specifies multiple nested levels of slicing.
|
221 |
+
|
222 |
+
Returns:
|
223 |
+
- List[np.ndarray]: A list of numpy arrays where each array represents a slice of embeddings.
|
224 |
+
|
225 |
+
Raises:
|
226 |
+
- TypeError: If `num_sentences` is not of type List[int] or List[List[int]].
|
227 |
+
|
228 |
+
Example Usage:
|
229 |
+
|
230 |
+
```python
|
231 |
+
embeddings = np.random.rand(10, 5)
|
232 |
+
num_sentences = [3, 2, 5]
|
233 |
+
result = slice_embeddings(embeddings, num_sentences)
|
234 |
+
# `result` will be a list of numpy arrays:
|
235 |
+
# [embeddings[:3], embeddings[3:5], embeddings[5:]]
|
236 |
+
|
237 |
+
num_sentences_nested = [[2, 1], [3, 4]]
|
238 |
+
result_nested = slice_embeddings(embeddings, num_sentences_nested)
|
239 |
+
# `result_nested` will be a nested list of numpy arrays:
|
240 |
+
# [[embeddings[:2], embeddings[2:3]], [embeddings[3:6], embeddings[6:]]]
|
241 |
+
|
242 |
+
slice_embeddings(embeddings, "invalid") # Raises a TypeError
|
243 |
+
```
|
244 |
+
"""
|
245 |
+
|
246 |
+
def _slice_embeddings(s_idx: int, n_sentences: List[int]):
|
247 |
+
"""
|
248 |
+
Helper function to slice embeddings starting from index `s_idx`.
|
249 |
+
|
250 |
+
Args:
|
251 |
+
- s_idx (int): Starting index for slicing.
|
252 |
+
- n_sentences (List[int]): List specifying number of sentences in each slice.
|
253 |
+
|
254 |
+
Returns:
|
255 |
+
- Tuple[List[np.ndarray], int]: A tuple containing a list of sliced embeddings and the next starting index.
|
256 |
+
"""
|
257 |
+
_result = []
|
258 |
+
for count in n_sentences:
|
259 |
+
_result.append(embeddings[s_idx:s_idx + count])
|
260 |
+
s_idx += count
|
261 |
+
return _result, s_idx
|
262 |
+
|
263 |
+
if isinstance(num_sentences, list) and all(isinstance(item, int) for item in num_sentences):
|
264 |
+
result, _ = _slice_embeddings(0, num_sentences)
|
265 |
+
return result
|
266 |
+
elif isinstance(num_sentences, list) and all(
|
267 |
+
isinstance(sublist, list) and all(
|
268 |
+
isinstance(item, int) for item in sublist
|
269 |
+
)
|
270 |
+
for sublist in num_sentences
|
271 |
+
):
|
272 |
+
nested_result = []
|
273 |
+
start_idx = 0
|
274 |
+
for nested_num_sentences in num_sentences:
|
275 |
+
embedding_slice, start_idx = _slice_embeddings(start_idx, nested_num_sentences)
|
276 |
+
nested_result.append(embedding_slice)
|
277 |
+
|
278 |
+
return nested_result
|
279 |
+
else:
|
280 |
+
raise TypeError(f"Incorrect Type for {num_sentences=}")
|