# Sem-nCG / tests.py
# Last commit by nbansal (e0e4e28): "Handled the edge cases and added better error message"
import unittest
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from .encoder_models import SBertEncoder, get_encoder, get_sbert_encoder
from .semncg import (
RankedGains,
compute_cosine_similarity,
compute_gain,
score_ncg,
compute_ncg,
_validate_input_format,
SemNCG
)
from .utils import (
get_gpu,
slice_embeddings,
is_nested_list_of_type,
flatten_list,
prep_sentences,
tokenize_and_prep_document
)
class TestUtils(unittest.TestCase):
    """Unit tests for the helper functions imported from ``utils``."""

    def _assert_nested_arrays_equal(self, actual, expected):
        """Recursively assert that ``actual`` matches ``expected`` element-wise.

        ``expected`` is either an array or an (arbitrarily nested) list of arrays.
        Lists are compared by length first and then item by item, so a missing or
        extra slice fails instead of being silently dropped by ``zip``.
        """
        if isinstance(expected, list):
            actual_items = list(actual)
            self.assertEqual(len(actual_items), len(expected))
            for actual_item, expected_item in zip(actual_items, expected):
                self._assert_nested_arrays_equal(actual_item, expected_item)
        else:
            np.testing.assert_array_equal(actual, expected)

    def test_get_gpu(self):
        """``get_gpu`` maps booleans, strings, ints and lists to device ids or "cpu"."""
        gpu_count = torch.cuda.device_count()
        gpu_available = torch.cuda.is_available()
        # Test single boolean input
        self.assertEqual(get_gpu(True), 0 if gpu_available else "cpu")
        self.assertEqual(get_gpu(False), "cpu")
        # Test single string input
        self.assertEqual(get_gpu("cpu"), "cpu")
        self.assertEqual(get_gpu("gpu"), 0 if gpu_available else "cpu")
        self.assertEqual(get_gpu("cuda"), 0 if gpu_available else "cpu")
        # Test single integer input
        self.assertEqual(get_gpu(0), 0 if gpu_available else "cpu")
        if not gpu_available:
            self.assertEqual(get_gpu(1), "cpu")
        elif gpu_count > 1:
            # Device 1 only exists on multi-GPU hosts; with exactly one GPU it is out
            # of range and is covered by the ValueError check at the end.
            self.assertEqual(get_gpu(1), 1)
        # Test list input with unique elements
        self.assertEqual(get_gpu([True, "cpu", 0]), [0, "cpu"] if gpu_available else ["cpu", "cpu", "cpu"])
        # Test list input with duplicate elements
        self.assertEqual(get_gpu([0, 0, "gpu"]), 0 if gpu_available else ["cpu", "cpu", "cpu"])
        # Test list input with duplicate elements of different types
        self.assertEqual(get_gpu([True, 0, "gpu"]), 0 if gpu_available else ["cpu", "cpu", "cpu"])
        # Test list input but only one element
        self.assertEqual(get_gpu([True]), 0 if gpu_available else "cpu")
        # Test list input with all integers. With exactly one GPU this list is [0],
        # which collapses to the scalar 0 like the ``[True]`` case above, so the
        # list-valued expectation only holds for zero or several devices.
        if gpu_count != 1:
            self.assertEqual(get_gpu(list(range(gpu_count))),
                             list(range(gpu_count)) if gpu_available else gpu_count * ["cpu"])
        with self.assertRaises(ValueError):
            get_gpu("invalid")
        if gpu_available:
            # Without CUDA, integer inputs fall back to "cpu" (see get_gpu(0) above), so
            # an out-of-range device index can only be rejected when GPUs exist.
            with self.assertRaises(ValueError):
                get_gpu(gpu_count)

    def test_prep_sentences(self):
        """``prep_sentences`` strips whitespace and drops punctuation-only entries."""
        # Test normal case
        self.assertEqual(prep_sentences(["Hello, world!", " This is a test. ", "!!!"]),
                         ['Hello, world!', 'This is a test.'])
        # Test case with only punctuations
        with self.assertRaises(ValueError):
            prep_sentences(["!!!", "..."])
        # Test case with empty list
        with self.assertRaises(ValueError):
            prep_sentences([])

    def test_tokenize_and_prep_document(self):
        """Documents are sentence-split when requested and cleaned either way."""
        # Test tokenize=True with string input
        self.assertEqual(tokenize_and_prep_document("Hello, world! This is a test.", True),
                         ['Hello, world!', 'This is a test.'])
        # Test tokenize=False with list of strings input
        self.assertEqual(tokenize_and_prep_document(["Hello, world!", "This is a test."], False),
                         ['Hello, world!', 'This is a test.'])
        # Test tokenize=True with empty document
        with self.assertRaises(ValueError):
            tokenize_and_prep_document("!!! ...", True)

    def test_slice_embeddings(self):
        """``slice_embeddings`` splits rows by (possibly nested) sentence counts.

        Note: ``assertTrue(actual, expected)`` would only check that ``actual`` is
        truthy (``expected`` becomes the failure message), so every case compares the
        slices explicitly.
        """
        # Case 1: flat counts covering every row
        embeddings = np.random.rand(10, 5)
        num_sentences = [3, 2, 5]
        expected_output = [embeddings[:3], embeddings[3:5], embeddings[5:]]
        self._assert_nested_arrays_equal(slice_embeddings(embeddings, num_sentences), expected_output)
        # Case 2: nested counts yield nested slices
        num_sentences_nested = [[2, 1], [3, 4]]
        expected_output_nested = [[embeddings[:2], embeddings[2:3]], [embeddings[3:6], embeddings[6:]]]
        self._assert_nested_arrays_equal(slice_embeddings(embeddings, num_sentences_nested),
                                         expected_output_nested)
        # Case 3: documents, references and predictions stacked in one matrix and
        # sliced group by group. Each group's counts cover only its own rows, so the
        # trailing rows of later groups must not leak into a group's last slice.
        document_sentences_count = [10, 8, 7]
        reference_sentences_count = [5, 3, 2]
        pred_sentences_count = [2, 2, 1]
        all_embeddings = np.random.rand(
            sum(document_sentences_count + reference_sentences_count + pred_sentences_count), 5,
        )
        ref_start = sum(document_sentences_count)
        pred_start = sum(document_sentences_count + reference_sentences_count)
        doc_rows = all_embeddings
        ref_rows = all_embeddings[ref_start:]
        pred_rows = all_embeddings[pred_start:]
        expected_doc_embeddings = [doc_rows[:10], doc_rows[10:18], doc_rows[18:25]]
        expected_ref_embeddings = [ref_rows[:5], ref_rows[5:8], ref_rows[8:10]]
        expected_pred_embeddings = [pred_rows[:2], pred_rows[2:4], pred_rows[4:5]]
        self._assert_nested_arrays_equal(slice_embeddings(doc_rows, document_sentences_count),
                                         expected_doc_embeddings)
        self._assert_nested_arrays_equal(slice_embeddings(ref_rows, reference_sentences_count),
                                         expected_ref_embeddings)
        self._assert_nested_arrays_equal(slice_embeddings(pred_rows, pred_sentences_count),
                                         expected_pred_embeddings)
        with self.assertRaises(TypeError):
            slice_embeddings(pred_rows, "invalid")

    def test_is_nested_list_of_type(self):
        """``is_nested_list_of_type`` returns ``(is_valid, err_msg)`` for a given depth."""
        # Depth 0, single element matching element_type
        self.assertEqual(is_nested_list_of_type("test", str, 0), (True, ""))
        # Depth 0, single element not matching element_type
        is_valid, _ = is_nested_list_of_type("test", int, 0)
        self.assertFalse(is_valid)
        # Depth 1, list of elements matching element_type
        self.assertEqual(is_nested_list_of_type(["apple", "banana"], str, 1), (True, ""))
        # Depth 1, list of elements not matching element_type
        is_valid, _ = is_nested_list_of_type([1, 2, 3], str, 1)
        self.assertFalse(is_valid)
        # Depth 0 (wrong depth for a list)
        is_valid, _ = is_nested_list_of_type([1, 2, 3], str, 0)
        self.assertFalse(is_valid)
        # Depth 2
        self.assertEqual(is_nested_list_of_type([[1, 2], [3, 4]], int, 2), (True, ""))
        self.assertEqual(is_nested_list_of_type([['1', '2'], ['3', '4']], str, 2), (True, ""))
        is_valid, _ = is_nested_list_of_type([[1, 2], ["a", "b"]], int, 2)
        self.assertFalse(is_valid)
        # Depth 3
        is_valid, _ = is_nested_list_of_type([[[1], [2]], [[3], [4]]], list, 3)
        self.assertFalse(is_valid)
        self.assertEqual(is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3), (True, ""))
        # Negative depth is rejected
        with self.assertRaises(ValueError):
            is_nested_list_of_type([1, 2], int, -1)

    def test_flatten_list(self):
        """``flatten_list`` flattens arbitrarily nested lists in order."""
        self.assertEqual(flatten_list([1, [2, 3], [[4], 5]]), [1, 2, 3, 4, 5])
        self.assertEqual(flatten_list([]), [])
        self.assertEqual(flatten_list([1, 2, 3]), [1, 2, 3])
        self.assertEqual(flatten_list([[[[1]]]]), [1])
class TestSBertEncoder(unittest.TestCase):
    """Tests for ``SBertEncoder`` wrapping a SentenceTransformer model."""

    def setUp(self) -> None:
        # Set up a test SentenceTransformer model
        self.model_name = "paraphrase-distilroberta-base-v1"
        self.sbert_model = get_sbert_encoder(self.model_name)
        self.device = "cpu"  # For testing on CPU
        self.batch_size = 32
        self.verbose = False
        self.encoder = SBertEncoder(self.sbert_model, self.device, self.batch_size, self.verbose)
        # Read the embedding width from the model instead of hard-coding 768, so the
        # shape checks stay correct if ``model_name`` is changed.
        self.embedding_dim = self.sbert_model.get_sentence_embedding_dimension()

    def test_encode_single_sentence(self):
        """Encoding one sentence yields a single embedding row."""
        embeddings = self.encoder.encode(["Hello, world!"])
        self.assertEqual(embeddings.shape, (1, self.embedding_dim))

    def test_encode_multiple_sentences(self):
        """Encoding N sentences yields N embedding rows."""
        sentences = ["Hello, world!", "This is a test."]
        embeddings = self.encoder.encode(sentences)
        self.assertEqual(embeddings.shape, (2, self.embedding_dim))

    def test_get_sbert_encoder(self):
        """``get_sbert_encoder`` returns a ``SentenceTransformer`` instance."""
        sbert_model = get_sbert_encoder(self.model_name)
        self.assertIsInstance(sbert_model, SentenceTransformer)

    def test_encode_with_gpu(self):
        """Encoding works on a single CUDA device when one is available."""
        if not torch.cuda.is_available():
            self.skipTest("CUDA not available, skipping GPU test.")
        encoder = get_encoder(self.sbert_model, "cuda", self.batch_size, self.verbose)
        sentences = ["Hello, world!", "This is a test."]
        embeddings = encoder.encode(sentences)
        self.assertEqual(embeddings.shape, (2, self.embedding_dim))

    def test_encode_multi_device(self):
        """Encoding across two CUDA devices returns one numpy row per sentence."""
        if torch.cuda.device_count() < 2:
            self.skipTest("Multi-GPU test requires at least 2 GPUs.")
        devices = ["cuda:0", "cuda:1"]
        encoder = get_encoder(self.sbert_model, devices, self.batch_size, self.verbose)
        sentences = ["This is a test sentence.", "Here is another sentence.", "This is a test sentence."]
        embeddings = encoder.encode(sentences)
        self.assertIsInstance(embeddings, np.ndarray)
        self.assertEqual(embeddings.shape[0], 3)
        self.assertEqual(embeddings.shape[1], self.embedding_dim)
class TestGetEncoder(unittest.TestCase):
    """Checks that ``get_encoder`` wraps several SentenceTransformer-loadable models."""

    def setUp(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.batch_size = 8
        self.verbose = False

    def _base_test(self, model_name):
        """Build an encoder for ``model_name`` and check its settings were kept."""
        encoder = get_encoder(
            get_sbert_encoder(model_name), self.device, self.batch_size, self.verbose
        )
        self.assertIsInstance(encoder, SBertEncoder)
        expected_settings = {
            "device": self.device,
            "batch_size": self.batch_size,
            "verbose": self.verbose,
        }
        for attr_name, expected_value in expected_settings.items():
            self.assertEqual(getattr(encoder, attr_name), expected_value)

    def test_get_sbert_encoder(self):
        self._base_test("stsb-roberta-large")

    def test_sbert_model(self):
        self._base_test("all-mpnet-base-v2")

    def test_huggingface_model(self):
        """Test Huggingface models which work with SBert library"""
        self._base_test("roberta-base")

    def test_get_encoder_environment_error(self):
        # An unknown model name must surface as EnvironmentError (an alias of OSError).
        with self.assertRaises(EnvironmentError):
            get_sbert_encoder("abc")

    def test_get_encoder_other_exception(self):
        # apple/OpenELM-270M cannot be loaded by the SentenceTransformer library.
        with self.assertRaises(RuntimeError):
            get_sbert_encoder("apple/OpenELM-270M")
class TestRankedGainsDataclass(unittest.TestCase):
    """``RankedGains`` exposes its constructor arguments as attributes."""

    def test_ranked_gains_dataclass(self):
        # Field name -> value, in the positional order RankedGains is constructed with.
        field_values = {
            "gt_gains": [("doc1", 0.8), ("doc2", 0.6)],
            "pred_gains": [("doc2", 0.7), ("doc1", 0.5)],
            "k": 2,
            "ncg": 0.75,
        }
        ranked_gains = RankedGains(*field_values.values())
        for field_name, value in field_values.items():
            self.assertEqual(getattr(ranked_gains, field_name), value)
class TestComputeCosineSimilarity(unittest.TestCase):
    """Tests for ``compute_cosine_similarity``."""

    def test_compute_cosine_similarity(self):
        doc_embeds = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
        ref_embeds = np.array([[0.2, 0.3, 0.4], [0.5, 0.6, 0.7]])
        similarity_scores = compute_cosine_similarity(doc_embeds, ref_embeds)
        # The expected values match, for each document row, the mean cosine similarity
        # to both reference rows:
        #   doc 0: (0.9926 + 0.9683) / 2 ~ 0.980
        #   doc 1: (0.9946 + 0.9996) / 2 ~ 0.997
        expected_scores = [0.980, 0.997]
        self.assertEqual(len(similarity_scores), len(expected_scores))
        for actual, expected in zip(similarity_scores, expected_scores):
            self.assertAlmostEqual(actual, expected, places=3)
class TestComputeGain(unittest.TestCase):
    """Tests for ``compute_gain``."""

    def test_compute_gain(self):
        sim_scores = [0.8, 0.6, 0.7]
        gains = compute_gain(sim_scores)
        # Sentences ordered by descending similarity (indices 0, 2, 1) are expected to
        # receive gains 3/6, 2/6 and 1/6, which sum to 1.
        expected_gains = [(0, 3 / 6), (2, 2 / 6), (1, 1 / 6)]
        self.assertEqual(len(gains), len(expected_gains))
        for (index, gain), (expected_index, expected_gain) in zip(gains, expected_gains):
            self.assertEqual(index, expected_index)
            # Gains are computed floats, so compare with a tolerance, not exact equality.
            self.assertAlmostEqual(gain, expected_gain)
class TestScoreNcg(unittest.TestCase):
    """Tests for ``score_ncg``."""

    def test_score_ncg(self):
        # The expected 0.778 matches 2.1 / 2.7, i.e. the summed model relevances over
        # the summed ground-truth relevances, to three decimal places.
        ncg_score = score_ncg([0.8, 0.7, 0.6], [1.0, 0.9, 0.8])
        self.assertAlmostEqual(ncg_score, 0.778, places=3)
class TestComputeNcg(unittest.TestCase):
    """Tests for ``compute_ncg``."""

    def test_compute_ncg(self):
        # (sentence index, gain) pairs as ranked by the model and by the ground truth.
        predicted = [(0, 0.8), (2, 0.7), (1, 0.6)]
        ground_truth = [(0, 1.0), (1, 0.9), (2, 0.8)]
        # k equals the number of sentences, so both rankings cover the same sentence
        # set -- presumably why a perfect score is expected.
        # TODO: Confirm this with Dr. Santu
        self.assertAlmostEqual(compute_ncg(predicted, ground_truth, 3), 1.0, places=6)
class TestValidateInputFormat(unittest.TestCase):
    """Tests for ``_validate_input_format`` with sentence tokenization enabled."""

    def test_validate_input_format(self):
        tokenize_sentences = True
        # With tokenization on, flat lists of raw strings are valid input.
        predictions = ["Prediction 1", "Prediction 2"]
        references = ["Reference 1", "Reference 2"]
        documents = ["Document 1", "Document 2"]
        try:
            _validate_input_format(tokenize_sentences, predictions, references, documents)
        except ValueError as e:
            self.fail(f"_validate_input_format raised ValueError unexpectedly: {str(e)}")

        # Pre-split (nested) inputs must be rejected; swap in one at a time.
        nested_predictions = [["Sentence 1 in prediction 1.", "Sentence 2 in prediction 1."],
                              ["Sentence 1 in prediction 2.", "Sentence 2 in prediction 2."]]
        nested_references = [["Sentences in reference 1."], ["Sentences in reference 2."]]
        nested_documents = [["Sentence 1 in document 1.", "Sentence 2 in document 1."],
                            ["Sentence 1 in document 2.", "Sentence 2 in document 2."]]
        invalid_argument_sets = (
            (nested_predictions, references, documents),
            (predictions, nested_references, documents),
            (predictions, references, nested_documents),
        )
        for preds, refs, docs in invalid_argument_sets:
            with self.assertRaises(ValueError):
                _validate_input_format(tokenize_sentences, preds, refs, docs)
class TestSemNCG(unittest.TestCase):
    """End-to-end tests for ``SemNCG.compute``."""

    # Shared fixture: one prediction/reference/document triple per example.
    PREDICTIONS = ["The cat sat on the mat.", "The quick brown fox jumps over the lazy dog."]
    REFERENCES = ["A cat was sitting on a mat.", "A quick brown fox jumped over a lazy dog."]
    DOCUMENTS = ["There was a cat on a mat.", "The quick brown fox jumped over the lazy dog."]

    def setUp(self):
        self.model_name = "stsb-distilbert-base"
        self.metric = SemNCG(self.model_name)

    def _basic_assertion(self, result, debug: bool = False):
        """Check the ``(score, details)`` tuple returned by ``SemNCG.compute``.

        ``score`` must be a float in [0, 1]. ``details`` must be a list whose items are
        ``RankedGains`` with ``ncg`` in [0, 1] when ``debug`` is True, and floats in
        [0, 1] otherwise.
        """
        self.assertIsInstance(result, tuple)
        self.assertEqual(len(result), 2)
        score, details = result
        self.assertIsInstance(score, float)
        self.assertGreaterEqual(score, 0.0)
        self.assertLessEqual(score, 1.0)
        self.assertIsInstance(details, list)
        for item in details:
            if debug:
                self.assertIsInstance(item, RankedGains)
                value = item.ncg
            else:
                self.assertIsInstance(item, float)
                value = item
            self.assertGreaterEqual(value, 0.0)
            self.assertLessEqual(value, 1.0)

    def test_compute_basic(self):
        """Raw-string inputs with default options give a well-formed result."""
        result = self.metric.compute(
            predictions=self.PREDICTIONS, references=self.REFERENCES, documents=self.DOCUMENTS
        )
        self._basic_assertion(result)

    def test_compute_with_tokenization(self):
        """Pre-split sentence lists with ``tokenize_sentences=False``.

        Despite the name, this exercises the path where sentence tokenization is
        skipped and each input is already a list of sentences.
        """
        predictions = [[sentence] for sentence in self.PREDICTIONS]
        references = [[sentence] for sentence in self.REFERENCES]
        documents = [[sentence] for sentence in self.DOCUMENTS]
        result = self.metric.compute(
            predictions=predictions, references=references, documents=documents, tokenize_sentences=False
        )
        self._basic_assertion(result)

    def test_compute_with_pre_compute_embeddings(self):
        """Computing all embeddings up front yields a well-formed result."""
        result = self.metric.compute(
            predictions=self.PREDICTIONS, references=self.REFERENCES, documents=self.DOCUMENTS,
            pre_compute_embeddings=True,
        )
        self._basic_assertion(result)

    def test_compute_with_debug(self):
        """``debug=True`` returns ``RankedGains`` objects as per-item details."""
        result = self.metric.compute(
            predictions=self.PREDICTIONS, references=self.REFERENCES, documents=self.DOCUMENTS, debug=True
        )
        self._basic_assertion(result, debug=True)

    def test_compute_invalid_input_format(self):
        """A bare string instead of a list of predictions is rejected."""
        predictions = "The cat sat on the mat."
        references = ["A cat was sitting on a mat."]
        documents = ["There was a cat on a mat."]
        with self.assertRaises(ValueError):
            self.metric.compute(predictions=predictions, references=references, documents=documents)

    def test_bad_inputs(self):
        """None, empty and empty-string inputs all raise.

        Each case runs in its own ``subTest`` so a case that fails to raise is
        reported without hiding the remaining cases.
        """
        reference = "A cat was sitting on a mat."
        document = "There was a cat on a mat."
        # (label, tokenize_sentences, predictions, references, documents)
        cases = [
            ("Case I: None prediction", True, [None], [reference], [document]),
            ("Case II: None sentence in pre-split prediction", False,
             [[reference, None]], [[reference, reference]], [[document, document]]),
            ("Case: Empty Input", True, [], [reference], [document]),
            ("Case: Empty String Input", True, [""], [reference], [document]),
        ]
        for label, tokenize_sentences, predictions, references, documents in cases:
            with self.subTest(label):
                with self.assertRaises(Exception) as ctx:
                    self.metric.compute(
                        predictions=predictions,
                        references=references,
                        documents=documents,
                        tokenize_sentences=tokenize_sentences,
                        pre_compute_embeddings=True,
                    )
                # Surface the message so the quality of the error text can be reviewed.
                print(f"{label}: raised {type(ctx.exception).__name__}: {ctx.exception}")
# Entry point for running this test module directly with per-test (verbosity=2) output.
# NOTE: the relative imports at the top of the file (``from .semncg import ...``) mean
# this must be launched as a module inside its package (``python -m <package>.tests``),
# not as a plain script path.
if __name__ == '__main__':
    unittest.main(verbosity=2)