import statistics
import unittest

import numpy as np
import torch
from sentence_transformers import SentenceTransformer

from encoder_models import SBertEncoder, get_encoder
from utils import (
    get_gpu,
    slice_embeddings,
    is_nested_list_of_type,
    flatten_list,
    compute_f1,
    Scores,
)


def _default_device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


class TestUtils(unittest.TestCase):
    """Unit tests for the helper functions in ``utils``."""

    def test_get_gpu(self):
        gpu_count = torch.cuda.device_count()
        gpu_available = torch.cuda.is_available()

        # Test single boolean input
        self.assertEqual(get_gpu(True), 0 if gpu_available else "cpu")
        self.assertEqual(get_gpu(False), "cpu")

        # Test single string input
        self.assertEqual(get_gpu("cpu"), "cpu")
        self.assertEqual(get_gpu("gpu"), 0 if gpu_available else "cpu")
        self.assertEqual(get_gpu("cuda"), 0 if gpu_available else "cpu")

        # Test single integer input. Index 1 is only a valid device with at least
        # two GPUs; with exactly one GPU it is out of range (checked at the end).
        self.assertEqual(get_gpu(0), 0 if gpu_available else "cpu")
        if not gpu_available:
            self.assertEqual(get_gpu(1), "cpu")
        elif gpu_count > 1:
            self.assertEqual(get_gpu(1), 1)

        # Test list input with unique elements
        self.assertEqual(
            get_gpu([True, "cpu", 0]),
            [0, "cpu"] if gpu_available else ["cpu", "cpu", "cpu"],
        )

        # Test list input with duplicate elements
        self.assertEqual(
            get_gpu([0, 0, "gpu"]),
            [0] if gpu_available else ["cpu", "cpu", "cpu"],
        )

        # Test list input with duplicate elements of different types
        self.assertEqual(
            get_gpu([True, 0, "gpu"]),
            [0] if gpu_available else ["cpu", "cpu", "cpu"],
        )

        # Test list input with all integers
        self.assertEqual(
            get_gpu(list(range(gpu_count))),
            list(range(gpu_count)) if gpu_available else gpu_count * ["cpu"],
        )

        with self.assertRaises(ValueError):
            get_gpu("invalid")

        # The first index past the last visible device must be rejected. Without
        # CUDA that index is 0, which this test expects to map to "cpu" above, so
        # the out-of-range check only makes sense when GPUs are present.
        if gpu_available:
            with self.assertRaises(ValueError):
                get_gpu(gpu_count)

    def test_slice_embeddings(self):
        embeddings = np.random.rand(10, 5)

        # Flat counts: one consecutive row slice per count.
        num_sentences = [3, 2, 5]
        expected_output = [embeddings[:3], embeddings[3:5], embeddings[5:]]
        actual_output = list(slice_embeddings(embeddings, num_sentences))
        # Check lengths explicitly; zip() alone would silently truncate.
        self.assertEqual(len(actual_output), len(expected_output))
        for actual, expected in zip(actual_output, expected_output):
            np.testing.assert_array_equal(actual, expected)

        # Nested counts: the slices keep the nesting of ``num_sentences_nested``.
        num_sentences_nested = [[2, 1], [3, 4]]
        expected_output_nested = [
            [embeddings[:2], embeddings[2:3]],
            [embeddings[3:6], embeddings[6:]],
        ]
        actual_output_nested = list(slice_embeddings(embeddings, num_sentences_nested))
        self.assertEqual(len(actual_output_nested), len(expected_output_nested))
        for actual_group, expected_group in zip(actual_output_nested, expected_output_nested):
            actual_group = list(actual_group)
            self.assertEqual(len(actual_group), len(expected_group))
            for actual, expected in zip(actual_group, expected_group):
                np.testing.assert_array_equal(actual, expected)

        with self.assertRaises(TypeError):
            slice_embeddings(embeddings, "invalid")

    def test_is_nested_list_of_type(self):
        # Test case: Depth 0, single element matching element_type
        self.assertTrue(is_nested_list_of_type("test", str, 0))
        # Test case: Depth 0, single element not matching element_type
        self.assertFalse(is_nested_list_of_type("test", int, 0))
        # Test case: Depth 1, list of elements matching element_type
        self.assertTrue(is_nested_list_of_type(["apple", "banana"], str, 1))
        # Test case: Depth 1, list of elements not matching element_type
        self.assertFalse(is_nested_list_of_type([1, 2, 3], str, 1))
        # Test case: Depth 0 given for a list (wrong depth), elements also not of element_type
        self.assertFalse(is_nested_list_of_type([1, 2, 3], str, 0))
        # Depth 2
        self.assertTrue(is_nested_list_of_type([[1, 2], [3, 4]], int, 2))
        self.assertTrue(is_nested_list_of_type([['1', '2'], ['3', '4']], str, 2))
        self.assertFalse(is_nested_list_of_type([[1, 2], ["a", "b"]], int, 2))
        # Depth 3
        self.assertFalse(is_nested_list_of_type([[[1], [2]], [[3], [4]]], list, 3))
        self.assertTrue(is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3))

        with self.assertRaises(ValueError):
            is_nested_list_of_type([1, 2], int, -1)

    def test_flatten_list(self):
        self.assertEqual(flatten_list([1, [2, 3], [[4], 5]]), [1, 2, 3, 4, 5])
        self.assertEqual(flatten_list([]), [])
        self.assertEqual(flatten_list([1, 2, 3]), [1, 2, 3])
        self.assertEqual(flatten_list([[[[1]]]]), [1])

    def test_compute_f1(self):
        self.assertAlmostEqual(compute_f1(0.5, 0.5), 0.5)
        self.assertAlmostEqual(compute_f1(1, 0), 0.0)
        self.assertAlmostEqual(compute_f1(0, 1), 0.0)
        self.assertAlmostEqual(compute_f1(1, 1), 1.0)

    def test_scores(self):
        scores = Scores(precision=0.8, recall=[0.7, 0.9])
        self.assertAlmostEqual(scores.f1, compute_f1(0.8, statistics.fmean([0.7, 0.9])))


class TestSBertEncoder(unittest.TestCase):
    """Tests for :class:`SBertEncoder`."""

    @classmethod
    def setUpClass(cls):
        """Build the shared single-device encoder once for the whole class.

        Loading the model is typically the slow part of these tests, so it is
        done once here rather than before every test.
        """
        cls.device = _default_device()
        cls.model_name = "stsb-roberta-large"
        cls.batch_size = 8
        cls.verbose = False
        cls.encoder = SBertEncoder(cls.model_name, cls.device, cls.batch_size, cls.verbose)

    def test_initialization(self):
        self.assertIsInstance(self.encoder.model, SentenceTransformer)
        self.assertEqual(self.encoder.device, self.device)
        self.assertEqual(self.encoder.batch_size, self.batch_size)
        self.assertEqual(self.encoder.verbose, self.verbose)

    def test_encode_single_device(self):
        sentences = ["This is a test sentence.", "Here is another sentence."]
        embeddings = self.encoder.encode(sentences)
        self.assertIsInstance(embeddings, np.ndarray)
        self.assertEqual(embeddings.shape[0], len(sentences))
        self.assertEqual(embeddings.shape[1], self.encoder.model.get_sentence_embedding_dimension())

    def test_encode_multi_device(self):
        if torch.cuda.device_count() < 2:
            self.skipTest("Multi-GPU test requires at least 2 GPUs.")  # raises SkipTest

        # A dedicated encoder for this test, so the shared one stays single-device.
        encoder = SBertEncoder(self.model_name, ["cuda:0", "cuda:1"], self.batch_size, self.verbose)
        sentences = ["This is a test sentence.", "Here is another sentence.", "This is a test sentence."]
        embeddings = encoder.encode(sentences)
        self.assertIsInstance(embeddings, np.ndarray)
        self.assertEqual(embeddings.shape[0], len(sentences))
        self.assertEqual(embeddings.shape[1], encoder.model.get_sentence_embedding_dimension())


class TestGetEncoder(unittest.TestCase):
    """Tests for the :func:`get_encoder` factory."""

    def test_get_sbert_encoder(self):
        model_name = "stsb-roberta-large"
        device = _default_device()
        batch_size = 8
        verbose = False

        encoder = get_encoder(model_name, device, batch_size, verbose)
        self.assertIsInstance(encoder, SBertEncoder)
        self.assertEqual(encoder.device, device)
        self.assertEqual(encoder.batch_size, batch_size)
        self.assertEqual(encoder.verbose, verbose)

    def test_get_use_encoder(self):
        model_name = "use"
        device = _default_device()
        batch_size = 8
        verbose = False

        encoder = get_encoder(model_name, device, batch_size, verbose)
        # SBertEncoder is returned for "use" for now.
        # TODO: once a dedicated USE encoder class exists, assert its type here and
        # check that model_name, device, batch_size and verbose are stored on it.
        self.assertIsInstance(encoder, SBertEncoder)


if __name__ == '__main__':
    unittest.main()