import statistics
import unittest

import numpy as np
import torch
from sentence_transformers import SentenceTransformer

from encoder_models import SBertEncoder, get_encoder
from utils import get_gpu, slice_embeddings, is_nested_list_of_type, flatten_list, compute_f1, Scores


class TestUtils(unittest.TestCase):
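    """Unit tests for the helper functions imported from utils."""
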
    def test_get_gpu(self):
        gpu_count = torch.cuda.device_count()
        gpu_available = torch.cuda.is_available()

        # Test single boolean input
        self.assertEqual(get_gpu(True), 0 if gpu_available else "cpu")
        self.assertEqual(get_gpu(False), "cpu")

        # Test single string input
        self.assertEqual(get_gpu("cpu"), "cpu")
        self.assertEqual(get_gpu("gpu"), 0 if gpu_available else "cpu")
        self.assertEqual(get_gpu("cuda"), 0 if gpu_available else "cpu")

        # Test single integer input
        self.assertEqual(get_gpu(0), 0 if gpu_available else "cpu")
        self.assertEqual(get_gpu(1), 1 if gpu_available else "cpu")

        # Test list input with unique elements
        self.assertEqual(get_gpu([True, "cpu", 0]), [0, "cpu"] if gpu_available else ["cpu", "cpu", "cpu"])

        # Test list input with duplicate elements
        self.assertEqual(get_gpu([0, 0, "gpu"]), [0] if gpu_available else ["cpu", "cpu", "cpu"])

        # Test list input with duplicate elements of different types
        self.assertEqual(get_gpu([True, 0, "gpu"]), [0] if gpu_available else ["cpu", "cpu", "cpu"])

        # Test list input with all integers
        self.assertEqual(get_gpu(list(range(gpu_count))),
                         list(range(gpu_count)) if gpu_available else gpu_count * ["cpu"])

        with self.assertRaises(ValueError):
            get_gpu("invalid")

        with self.assertRaises(ValueError):
            get_gpu(torch.cuda.device_count())

    def test_slice_embeddings(self):
        embeddings = np.random.rand(10, 5)
        num_sentences = [3, 2, 5]
        expected_output = [embeddings[:3], embeddings[3:5], embeddings[5:]]
        self.assertTrue(
            all(np.array_equal(a, b) for a, b in zip(slice_embeddings(embeddings, num_sentences),
                                                     expected_output))
        )

        num_sentences_nested = [[2, 1], [3, 4]]
        expected_output_nested = [[embeddings[:2], embeddings[2:3]], [embeddings[3:6], embeddings[6:]]]
        # assertTrue with two positional arguments treats the second as a message, so compare explicitly.
        actual_output_nested = slice_embeddings(embeddings, num_sentences_nested)
        self.assertTrue(
            all(np.array_equal(a, b)
                for actual, expected in zip(actual_output_nested, expected_output_nested)
                for a, b in zip(actual, expected))
        )

        with self.assertRaises(TypeError):
            slice_embeddings(embeddings, "invalid")

    def test_is_nested_list_of_type(self):
        # Test case: Depth 0, single element matching element_type
        self.assertTrue(is_nested_list_of_type("test", str, 0))

        # Test case: Depth 0, single element not matching element_type
        self.assertFalse(is_nested_list_of_type("test", int, 0))

        # Test case: Depth 1, list of elements matching element_type
        self.assertTrue(is_nested_list_of_type(["apple", "banana"], str, 1))

        # Test case: Depth 1, list of elements not matching element_type
        self.assertFalse(is_nested_list_of_type([1, 2, 3], str, 1))

        # Test case: list input declared with depth 0 (wrong depth), should be False
        self.assertFalse(is_nested_list_of_type([1, 2, 3], str, 0))

        # Depth 2
        self.assertTrue(is_nested_list_of_type([[1, 2], [3, 4]], int, 2))
        self.assertTrue(is_nested_list_of_type([['1', '2'], ['3', '4']], str, 2))
        self.assertFalse(is_nested_list_of_type([[1, 2], ["a", "b"]], int, 2))

        # Depth 3
        self.assertFalse(is_nested_list_of_type([[[1], [2]], [[3], [4]]], list, 3))
        self.assertTrue(is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3))

        with self.assertRaises(ValueError):
            is_nested_list_of_type([1, 2], int, -1)

    def test_flatten_list(self):
        self.assertEqual(flatten_list([1, [2, 3], [[4], 5]]), [1, 2, 3, 4, 5])
        self.assertEqual(flatten_list([]), [])
        self.assertEqual(flatten_list([1, 2, 3]), [1, 2, 3])
        self.assertEqual(flatten_list([[[[1]]]]), [1])

    def test_compute_f1(self):
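        # F1 is the harmonic mean of precision and recall, so it is 0 whenever either input is 0.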
        self.assertAlmostEqual(compute_f1(0.5, 0.5), 0.5)
        self.assertAlmostEqual(compute_f1(1, 0), 0.0)
        self.assertAlmostEqual(compute_f1(0, 1), 0.0)
        self.assertAlmostEqual(compute_f1(1, 1), 1.0)

    def test_scores(self):
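        # Scores.f1 is expected to combine the precision with the mean of the recall values.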
        scores = Scores(precision=0.8, recall=[0.7, 0.9])
        self.assertAlmostEqual(scores.f1, compute_f1(0.8, statistics.fmean([0.7, 0.9])))


class TestSBertEncoder(unittest.TestCase):
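    # setUp takes an optional device argument so that test_encode_multi_device can
    # re-initialise the encoder on a list of GPUs; unittest itself calls it with no arguments.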
    def setUp(self, device=None):
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device
        self.model_name = "stsb-roberta-large"
        self.batch_size = 8
        self.verbose = False
        self.encoder = SBertEncoder(self.model_name, self.device, self.batch_size, self.verbose)

    def test_initialization(self):
        self.assertIsInstance(self.encoder.model, SentenceTransformer)
        self.assertEqual(self.encoder.device, self.device)
        self.assertEqual(self.encoder.batch_size, self.batch_size)
        self.assertEqual(self.encoder.verbose, self.verbose)

    def test_encode_single_device(self):
        sentences = ["This is a test sentence.", "Here is another sentence."]
        embeddings = self.encoder.encode(sentences)
        self.assertIsInstance(embeddings, np.ndarray)
        self.assertEqual(embeddings.shape[0], len(sentences))
        self.assertEqual(embeddings.shape[1], self.encoder.model.get_sentence_embedding_dimension())

    def test_encode_multi_device(self):
        if torch.cuda.device_count() < 2:
            self.skipTest("Multi-GPU test requires at least 2 GPUs.")
        else:
            devices = ["cuda:0", "cuda:1"]
            self.setUp(devices)
            sentences = ["This is a test sentence.", "Here is another sentence.", "This is a test sentence."]
            embeddings = self.encoder.encode(sentences)
            self.assertIsInstance(embeddings, np.ndarray)
            self.assertEqual(embeddings.shape[0], 3)
            self.assertEqual(embeddings.shape[1], self.encoder.model.get_sentence_embedding_dimension())


class TestGetEncoder(unittest.TestCase):
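    """Tests for the get_encoder factory function."""
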
    def test_get_sbert_encoder(self):
        model_name = "stsb-roberta-large"
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        batch_size = 8
        verbose = False

        encoder = get_encoder(model_name, device, batch_size, verbose)
        self.assertIsInstance(encoder, SBertEncoder)
        self.assertEqual(encoder.device, device)
        self.assertEqual(encoder.batch_size, batch_size)
        self.assertEqual(encoder.verbose, verbose)

    def test_get_use_encoder(self):
        model_name = "use"
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        batch_size = 8
        verbose = False

        encoder = get_encoder(model_name, device, batch_size, verbose)
        self.assertIsInstance(encoder, SBertEncoder)  # SBertEncoder is returned for "use" for now
        # Uncomment below when implementing USE class
        # self.assertIsInstance(encoder, USE)
        # self.assertEqual(encoder.model_name, model_name)
        # self.assertEqual(encoder.device, device)
        # self.assertEqual(encoder.batch_size, batch_size)
        # self.assertEqual(encoder.verbose, verbose)


if __name__ == '__main__':
    unittest.main()