prompt-engine / tests /test_load_data.py
Lazar Radojevic
add poe tasks
0b497e7
import unittest
from unittest.mock import MagicMock, patch
from src.prompt_loader import PromptLoader
class TestPromptLoader(unittest.TestCase):
    """Unit tests for ``PromptLoader.load_data``.

    ``load_dataset`` is patched in every test so that no real dataset is
    fetched; the loader sees a small in-memory stand-in instead.
    """

    # Prompts served by the patched ``load_dataset`` in every test.
    PROMPTS = ["prompt1", "prompt2", "prompt3"]

    def setUp(self) -> None:
        """Build the fake dataset and a fresh, seeded loader for each test."""
        # Same nesting the tests read back through the loader:
        # dataset["train"]["prompt"] -> list of prompts.
        self.mock_data = {"train": {"prompt": list(self.PROMPTS)}}
        self.loader = PromptLoader(seed=42)

    @patch("src.prompt_loader.load_dataset")
    def test_load_data_without_size(self, mock_load_dataset: MagicMock) -> None:
        """Loading without a size stores every prompt on ``loader.data``."""
        mock_load_dataset.return_value = self.mock_data

        self.loader.load_data()

        self.assertEqual(self.loader.data, self.PROMPTS)

    @patch("src.prompt_loader.load_dataset")
    def test_load_data_with_size(self, mock_load_dataset: MagicMock) -> None:
        """Requesting ``size=2`` returns two prompts taken from the dataset."""
        mock_load_dataset.return_value = self.mock_data
        self.loader.load_data()

        sampled_data = self.loader.load_data(size=2)

        self.assertEqual(len(sampled_data), 2)
        self.assertTrue(set(sampled_data).issubset(set(self.PROMPTS)))

    @patch("src.prompt_loader.load_dataset")
    def test_load_data_size_exceeds(self, mock_load_dataset: MagicMock) -> None:
        """Requesting more prompts than exist raises ``ValueError``."""
        mock_load_dataset.return_value = self.mock_data
        # Load first, outside the assertRaises block, so that only the
        # oversized request is allowed to raise.
        self.loader.load_data()

        with self.assertRaises(ValueError):
            self.loader.load_data(size=10)

    @patch("src.prompt_loader.load_dataset")
    def test_data_loading_on_demand(self, mock_load_dataset: MagicMock) -> None:
        """The dataset is fetched on ``load_data``, not at construction."""
        mock_load_dataset.return_value = self.mock_data

        # The loader was created in setUp, before this patch was active,
        # so the mock must still be untouched at this point.
        mock_load_dataset.assert_not_called()

        self.loader.load_data()

        mock_load_dataset.assert_called_once()

    @patch("src.prompt_loader.load_dataset")
    def test_random_sampling(self, mock_load_dataset: MagicMock) -> None:
        """A sized request yields a valid sample of the dataset."""
        mock_load_dataset.return_value = self.mock_data
        self.loader.load_data()

        sample = self.loader.load_data(size=2)

        self.assertEqual(len(sample), 2)
        self.assertTrue(set(sample).issubset(set(self.PROMPTS)))
        # The previous check asserted sample != ["prompt1", "prompt2"], but
        # that list is a legitimate random sample; the test passed only by
        # accident of the seed. Assert a seed-independent property instead.
        # NOTE(review): assumes sampling without replacement, as suggested by
        # the ValueError on oversized requests - confirm against PromptLoader.
        self.assertEqual(len(set(sample)), len(sample))
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()