import unittest
from unittest.mock import MagicMock, patch

from src.prompt_loader import PromptLoader


class TestPromptLoader(unittest.TestCase):
    """Unit tests for ``PromptLoader.load_data`` with ``load_dataset`` mocked out."""

    def setUp(self) -> None:
        """Create the fake dataset payload and a seeded loader shared by the tests."""
        # Stand-in for the value returned by the patched ``load_dataset``.
        self.mock_data = {"train": {"prompt": ["prompt1", "prompt2", "prompt3"]}}
        self.loader = PromptLoader(seed=42)

    @patch("src.prompt_loader.load_dataset")
    def test_load_data_without_size(self, mock_load_dataset: MagicMock) -> None:
        """Calling ``load_data()`` with no size exposes every prompt on ``loader.data``."""
        mock_load_dataset.return_value = self.mock_data
        self.loader.load_data()
        self.assertEqual(self.loader.data, ["prompt1", "prompt2", "prompt3"])

    @patch("src.prompt_loader.load_dataset")
    def test_load_data_with_size(self, mock_load_dataset: MagicMock) -> None:
        """``load_data(size=2)`` returns two items drawn from the known prompts."""
        mock_load_dataset.return_value = self.mock_data
        self.loader.load_data()
        sampled_data = self.loader.load_data(size=2)
        self.assertEqual(len(sampled_data), 2)
        self.assertTrue(set(sampled_data).issubset({"prompt1", "prompt2", "prompt3"}))

    @patch("src.prompt_loader.load_dataset")
    def test_load_data_size_exceeds(self, mock_load_dataset: MagicMock) -> None:
        """Requesting more prompts than are available raises ``ValueError``."""
        mock_load_dataset.return_value = self.mock_data
        self.loader.load_data()
        with self.assertRaises(ValueError):
            self.loader.load_data(size=10)

    @patch("src.prompt_loader.load_dataset")
    def test_data_loading_on_demand(self, mock_load_dataset: MagicMock) -> None:
        """``load_dataset`` is not called at construction, only by ``load_data()``."""
        mock_load_dataset.return_value = self.mock_data
        # Build the loader *inside* the patch. ``self.loader`` is created in
        # ``setUp`` before this decorator's patch is active, so an eager call made
        # by the constructor would never be recorded on this mock and the
        # ``assert_not_called`` check below would pass vacuously.
        loader = PromptLoader(seed=42)
        mock_load_dataset.assert_not_called()
        loader.load_data()
        mock_load_dataset.assert_called_once()

    @patch("src.prompt_loader.load_dataset")
    def test_random_sampling(self, mock_load_dataset: MagicMock) -> None:
        """A sized request returns a valid sample that is not just the first two prompts."""
        mock_load_dataset.return_value = self.mock_data
        self.loader.load_data()
        sample = self.loader.load_data(size=2)
        self.assertEqual(len(sample), 2)
        self.assertTrue(set(sample).issubset({"prompt1", "prompt2", "prompt3"}))
        # NOTE(review): this inequality only holds for particular seeds; it
        # presumably relies on seed=42 not producing ["prompt1", "prompt2"] —
        # confirm against PromptLoader's sampling implementation.
        self.assertNotEqual(sample, ["prompt1", "prompt2"])


if __name__ == "__main__":
    unittest.main()