Spaces:
Sleeping
Sleeping
File size: 2,142 Bytes
da82b2b 0b497e7 da82b2b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import unittest
from unittest.mock import MagicMock, patch
from src.prompt_loader import PromptLoader
class TestPromptLoader(unittest.TestCase):
    """Unit tests for ``PromptLoader`` with ``load_dataset`` mocked out.

    Each test patches ``src.prompt_loader.load_dataset`` so no real dataset is
    fetched; the mock returns ``self.mock_data``, a dict of the form
    ``{"train": {"prompt": [...]}}``.
    """

    # Prompts served by the mocked dataset. Stored as a tuple so no test can
    # mutate the shared expectation; each use builds a fresh list from it.
    PROMPTS = ("prompt1", "prompt2", "prompt3")

    def setUp(self) -> None:
        # Fresh mock payload per test so any in-place change made by the
        # loader cannot leak into another test.
        self.mock_data = {"train": {"prompt": list(self.PROMPTS)}}
        # NOTE: setUp runs before the @patch decorator on a test method takes
        # effect, so this constructor call is never observed by the mock.
        self.loader = PromptLoader(seed=42)

    @patch("src.prompt_loader.load_dataset")
    def test_load_data_without_size(self, mock_load_dataset: MagicMock) -> None:
        """Without ``size``, ``load_data`` stores every prompt on ``data``."""
        mock_load_dataset.return_value = self.mock_data
        self.loader.load_data()
        self.assertEqual(self.loader.data, list(self.PROMPTS))

    @patch("src.prompt_loader.load_dataset")
    def test_load_data_with_size(self, mock_load_dataset: MagicMock) -> None:
        """With ``size=2``, ``load_data`` returns two prompts from the dataset."""
        mock_load_dataset.return_value = self.mock_data
        self.loader.load_data()
        sampled_data = self.loader.load_data(size=2)
        self.assertEqual(len(sampled_data), 2)
        self.assertTrue(set(sampled_data).issubset(self.PROMPTS))

    @patch("src.prompt_loader.load_dataset")
    def test_load_data_size_exceeds(self, mock_load_dataset: MagicMock) -> None:
        """Requesting more prompts than the dataset holds raises ``ValueError``."""
        mock_load_dataset.return_value = self.mock_data
        self.loader.load_data()
        with self.assertRaises(ValueError):
            self.loader.load_data(size=10)

    @patch("src.prompt_loader.load_dataset")
    def test_data_loading_on_demand(self, mock_load_dataset: MagicMock) -> None:
        """Constructing a loader must not fetch data; ``load_data`` must.

        The loader is created *inside* the patched scope. The one built in
        ``setUp`` predates the patch, so checking the mock against it would
        pass regardless of what the constructor does.
        """
        mock_load_dataset.return_value = self.mock_data
        loader = PromptLoader(seed=42)
        mock_load_dataset.assert_not_called()
        loader.load_data()
        mock_load_dataset.assert_called_once()

    @patch("src.prompt_loader.load_dataset")
    def test_random_sampling(self, mock_load_dataset: MagicMock) -> None:
        """Sampling returns valid prompts and is reproducible for a fixed seed."""
        mock_load_dataset.return_value = self.mock_data
        self.loader.load_data()
        sample = self.loader.load_data(size=2)
        self.assertEqual(len(sample), 2)
        self.assertTrue(set(sample).issubset(self.PROMPTS))
        # Seed-specific expectation: with seed=42 the draw is not the dataset's
        # leading prefix (i.e. the loader is not simply slicing). Revisit this
        # if the seed or the sampling routine changes.
        self.assertNotEqual(sample, ["prompt1", "prompt2"])
        # A second loader with the same seed, driven through the same calls,
        # should reproduce the same draw.
        twin = PromptLoader(seed=42)
        twin.load_data()
        self.assertEqual(twin.load_data(size=2), sample)
if __name__ == "__main__":
    # Run this module's tests when it is executed directly as a script.
    unittest.main()
|