# Spaces:
# Sleeping
# Sleeping
import unittest
from unittest.mock import MagicMock, patch

from src.prompt_loader import PromptLoader
# NOTE(review): assumes src.prompt_loader binds ``load_dataset`` in its own
# namespace (e.g. ``from datasets import load_dataset``) -- confirm the target.
@patch("src.prompt_loader.load_dataset")
class TestPromptLoader(unittest.TestCase):
    """Unit tests for ``PromptLoader`` with its dataset loader patched out.

    ``patch`` used as a class decorator wraps every method whose name starts
    with ``test`` and passes the mock in as the extra ``mock_load_dataset``
    argument. ``setUp`` is not a test method, so it runs unpatched.
    """

    def setUp(self) -> None:
        # Fake dataset handed back by the patched ``load_dataset``: a "train"
        # split whose "prompt" column holds three prompts.
        self.mock_data = {"train": {"prompt": ["prompt1", "prompt2", "prompt3"]}}
        self.all_prompts = {"prompt1", "prompt2", "prompt3"}
        self.loader = PromptLoader(seed=42)

    def test_load_data_without_size(self, mock_load_dataset: MagicMock) -> None:
        """``load_data()`` with no size stores every prompt on ``loader.data``."""
        mock_load_dataset.return_value = self.mock_data
        self.loader.load_data()
        self.assertEqual(self.loader.data, ["prompt1", "prompt2", "prompt3"])

    def test_load_data_with_size(self, mock_load_dataset: MagicMock) -> None:
        """``load_data(size=2)`` returns two prompts drawn from the dataset."""
        mock_load_dataset.return_value = self.mock_data
        self.loader.load_data()
        sampled_data = self.loader.load_data(size=2)
        self.assertEqual(len(sampled_data), 2)
        # Set <= set is a subset test; on failure both sets are shown.
        self.assertLessEqual(set(sampled_data), self.all_prompts)

    def test_load_data_size_exceeds(self, mock_load_dataset: MagicMock) -> None:
        """Requesting more prompts than the dataset holds raises ValueError."""
        mock_load_dataset.return_value = self.mock_data
        self.loader.load_data()
        with self.assertRaises(ValueError):
            self.loader.load_data(size=10)

    def test_data_loading_on_demand(self, mock_load_dataset: MagicMock) -> None:
        """The dataset is not fetched at construction, and fetched once by load_data()."""
        mock_load_dataset.return_value = self.mock_data
        # Construct the loader while the patch is active. ``self.loader`` was
        # built in setUp before the patch applied, so a constructor-time call
        # to load_dataset would never have been recorded by this mock.
        loader = PromptLoader(seed=42)
        mock_load_dataset.assert_not_called()
        loader.load_data()
        mock_load_dataset.assert_called_once()

    def test_random_sampling(self, mock_load_dataset: MagicMock) -> None:
        """A sized load returns a sample rather than the leading prompts."""
        mock_load_dataset.return_value = self.mock_data
        self.loader.load_data()
        sample = self.loader.load_data(size=2)
        self.assertEqual(len(sample), 2)
        self.assertLessEqual(set(sample), self.all_prompts)
        # NOTE(review): this relies on the particular draw produced with
        # seed=42; presumably PromptLoader seeds its RNG from ``seed`` --
        # confirm, otherwise this assertion can fail by chance.
        self.assertNotEqual(sample, ["prompt1", "prompt2"])
if __name__ == "__main__":
    # Executed as a script: discover and run the TestCase classes in this
    # module ("__main__" is also unittest.main's default module).
    unittest.main(module="__main__")