File size: 2,142 Bytes
da82b2b
0b497e7
 
 
da82b2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import unittest
from unittest.mock import MagicMock, patch

from src.prompt_loader import PromptLoader


class TestPromptLoader(unittest.TestCase):
    """Unit tests for ``PromptLoader``.

    Every test patches ``src.prompt_loader.load_dataset`` so no real dataset is
    fetched. The fake dataset exposes a ``"train"`` split with a ``"prompt"``
    column. These tests assume that is what the loader reads; confirm against
    ``PromptLoader.load_data``.
    """

    def setUp(self) -> None:
        # Fake dataset handed back by the patched ``load_dataset``.
        self.mock_data = {"train": {"prompt": ["prompt1", "prompt2", "prompt3"]}}
        # The full prompt set, used for subset checks on sampled results.
        self.all_prompts = {"prompt1", "prompt2", "prompt3"}
        # Shared loader for tests that only exercise ``load_data``. It is built
        # outside any patch context, so tests about construction-time behavior
        # must create their own loader while the patch is active.
        self.loader = PromptLoader(seed=42)

    @patch("src.prompt_loader.load_dataset")
    def test_load_data_without_size(self, mock_load_dataset: MagicMock) -> None:
        """Loading without ``size`` stores every prompt, in order, on ``data``."""
        mock_load_dataset.return_value = self.mock_data

        self.loader.load_data()
        self.assertEqual(self.loader.data, ["prompt1", "prompt2", "prompt3"])

    @patch("src.prompt_loader.load_dataset")
    def test_load_data_with_size(self, mock_load_dataset: MagicMock) -> None:
        """Passing ``size`` returns that many prompts taken from the dataset."""
        mock_load_dataset.return_value = self.mock_data
        self.loader.load_data()
        sampled_data = self.loader.load_data(size=2)

        self.assertEqual(len(sampled_data), 2)
        self.assertTrue(set(sampled_data).issubset(self.all_prompts))

    @patch("src.prompt_loader.load_dataset")
    def test_load_data_size_exceeds(self, mock_load_dataset: MagicMock) -> None:
        """A ``size`` larger than the dataset raises ``ValueError``."""
        mock_load_dataset.return_value = self.mock_data
        self.loader.load_data()

        with self.assertRaises(ValueError):
            self.loader.load_data(size=10)

    @patch("src.prompt_loader.load_dataset")
    def test_data_loading_on_demand(self, mock_load_dataset: MagicMock) -> None:
        """Construction must not fetch the dataset; ``load_data`` fetches it once."""
        mock_load_dataset.return_value = self.mock_data
        # Build the loader while the patch is active. The setUp loader was
        # created before patching, so asserting on the mock after it would not
        # prove anything about what the constructor does.
        loader = PromptLoader(seed=42)
        mock_load_dataset.assert_not_called()

        loader.load_data()
        mock_load_dataset.assert_called_once()

    @patch("src.prompt_loader.load_dataset")
    def test_random_sampling(self, mock_load_dataset: MagicMock) -> None:
        """A sized load is a random draw, not just the leading prompts."""
        mock_load_dataset.return_value = self.mock_data

        self.loader.load_data()
        sample = self.loader.load_data(size=2)

        self.assertEqual(len(sample), 2)
        self.assertTrue(set(sample).issubset(self.all_prompts))
        # NOTE(review): this holds only for the particular draw seed=42 yields.
        # A correct sampler could legitimately return the first two prompts in
        # order under another seed (1 of the 6 ordered pairs).
        self.assertNotEqual(sample, ["prompt1", "prompt2"])

    @patch("src.prompt_loader.load_dataset")
    def test_same_seed_same_sample(self, mock_load_dataset: MagicMock) -> None:
        """Two loaders constructed with the same seed draw the same sample."""
        mock_load_dataset.return_value = self.mock_data

        samples = []
        # Build and sample each loader in turn, not interleaved. That way the
        # check still holds if the loader seeds a process-global RNG in its
        # constructor.
        for _ in range(2):
            loader = PromptLoader(seed=42)
            loader.load_data()
            samples.append(loader.load_data(size=2))

        self.assertEqual(samples[0], samples[1])


if __name__ == "__main__":
    # Executing this file directly discovers and runs every TestCase defined
    # in it. Passing module="__main__" spells out unittest.main's default.
    unittest.main(module="__main__")