Lazar Radojevic committed on
Commit
da82b2b
·
1 Parent(s): 1cd5053

final version

Browse files
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
backend/routes.py CHANGED
@@ -7,6 +7,11 @@ from backend.models import QueryRequest, QueryResponse, SimilarPrompt
7
  from src.prompt_loader import PromptLoader
8
  from src.search_engine import PromptSearchEngine
9
 
 
 
 
 
 
10
  # Constants
11
  SEED = int(os.getenv("SEED", 42))
12
  DATASET_SIZE = int(os.getenv("DATASET_SIZE", 1000))
 
7
  from src.prompt_loader import PromptLoader
8
  from src.search_engine import PromptSearchEngine
9
 
10
+ from dotenv import load_dotenv
11
+
12
+ # Load environment variables from .env file
13
+ load_dotenv()
14
+
15
  # Constants
16
  SEED = int(os.getenv("SEED", 42))
17
  DATASET_SIZE = int(os.getenv("DATASET_SIZE", 1000))
frontend/app_ui.py CHANGED
@@ -3,6 +3,11 @@ import os
3
  import requests
4
  import streamlit as st
5
 
 
 
 
 
 
6
  # Read API URL from environment variable
7
  API_URL = os.getenv("API_URL", "http://localhost:8000")
8
 
 
3
  import requests
4
  import streamlit as st
5
 
6
+ from dotenv import load_dotenv
7
+
8
+ # Load environment variables from .env file
9
+ load_dotenv()
10
+
11
  # Read API URL from environment variable
12
  API_URL = os.getenv("API_URL", "http://localhost:8000")
13
 
poe/common-tasks.toml CHANGED
@@ -34,7 +34,7 @@ cmd = "ruff check ."
34
 
35
  [tool.poe.tasks.test]
36
  help = "Run unit tests"
37
- cmd = "pytest -p no:cacheprovider"
38
 
39
  [tool.poe.tasks.clean]
40
  help = "Remove automatically generated files"
 
34
 
35
  [tool.poe.tasks.test]
36
  help = "Run unit tests"
37
+ cmd = "python -m unittest discover -s tests"
38
 
39
  [tool.poe.tasks.clean]
40
  help = "Remove automatically generated files"
poetry.lock CHANGED
@@ -3718,4 +3718,4 @@ multidict = ">=4.0"
3718
  [metadata]
3719
  lock-version = "2.0"
3720
  python-versions = "^3.10"
3721
- content-hash = "38832b2f1f7e879f5efe88601e5ba8d8971bbbe8b4326625936762f860a7c128"
 
3718
  [metadata]
3719
  lock-version = "2.0"
3720
  python-versions = "^3.10"
3721
+ content-hash = "14c56c888e2fbf236863e1a06b7a2a42c79377dea1917f6d7387ed106713abfd"
pyproject.toml CHANGED
@@ -15,6 +15,7 @@ numpy = "1.26.4"
15
  fastapi = "^0.111.1"
16
  uvicorn = "^0.30.3"
17
  streamlit = "^1.37.0"
 
18
 
19
  [tool.poetry.group.dev.dependencies]
20
  black = "^24.1.1"
 
15
  fastapi = "^0.111.1"
16
  uvicorn = "^0.30.3"
17
  streamlit = "^1.37.0"
18
+ python-dotenv = "^1.0.1"
19
 
20
  [tool.poetry.group.dev.dependencies]
21
  black = "^24.1.1"
src/prompt_loader.py CHANGED
@@ -19,7 +19,7 @@ class PromptLoader:
19
  self.randomizer = random.Random(seed)
20
  self.data: Optional[List[str]] = None
21
 
22
- def _load_data(self) -> None:
23
  """
24
  Loads the dataset of prompts and stores them in the `data` attribute.
25
 
@@ -33,7 +33,7 @@ class PromptLoader:
33
  """
34
  Loads and samples prompts from the dataset.
35
 
36
- If the dataset is not already loaded, it calls `_load_data()` to load it.
37
 
38
  Args:
39
  size (Optional[int]): The number of prompts to sample. If not specified, all loaded prompts are returned.
@@ -46,7 +46,7 @@ class PromptLoader:
46
  ValueError: If `size` is specified and is greater than the number of available prompts.
47
  """
48
  if not self.data:
49
- self._load_data()
50
 
51
  if size:
52
  if size > len(self.data):
 
19
  self.randomizer = random.Random(seed)
20
  self.data: Optional[List[str]] = None
21
 
22
+ def _get_data(self) -> None:
23
  """
24
  Loads the dataset of prompts and stores them in the `data` attribute.
25
 
 
33
  """
34
  Loads and samples prompts from the dataset.
35
 
36
+ If the dataset is not already loaded, it calls `_get_data()` to load it.
37
 
38
  Args:
39
  size (Optional[int]): The number of prompts to sample. If not specified, all loaded prompts are returned.
 
46
  ValueError: If `size` is specified and is greater than the number of available prompts.
47
  """
48
  if not self.data:
49
+ self._get_data()
50
 
51
  if size:
52
  if size > len(self.data):
tests/test_load_data.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ from unittest.mock import patch, MagicMock
3
+ from src.prompt_loader import (
4
+ PromptLoader,
5
+ )
6
+
7
+
8
+ class TestPromptLoader(unittest.TestCase):
9
+
10
+ def setUp(self) -> None:
11
+ # Set up a mock dataset for testing
12
+ self.mock_data = {"train": {"prompt": ["prompt1", "prompt2", "prompt3"]}}
13
+ self.loader = PromptLoader(seed=42)
14
+
15
+ @patch("src.prompt_loader.load_dataset")
16
+ def test_load_data_without_size(self, mock_load_dataset: MagicMock) -> None:
17
+ mock_load_dataset.return_value = self.mock_data
18
+
19
+ self.loader.load_data()
20
+ self.assertEqual(self.loader.data, ["prompt1", "prompt2", "prompt3"])
21
+
22
+ @patch("src.prompt_loader.load_dataset")
23
+ def test_load_data_with_size(self, mock_load_dataset: MagicMock) -> None:
24
+ mock_load_dataset.return_value = self.mock_data
25
+ self.loader.load_data()
26
+ sampled_data = self.loader.load_data(size=2)
27
+
28
+ self.assertEqual(len(sampled_data), 2)
29
+ self.assertTrue(set(sampled_data).issubset({"prompt1", "prompt2", "prompt3"}))
30
+
31
+ @patch("src.prompt_loader.load_dataset")
32
+ def test_load_data_size_exceeds(self, mock_load_dataset: MagicMock) -> None:
33
+ mock_load_dataset.return_value = self.mock_data
34
+ self.loader.load_data()
35
+
36
+ with self.assertRaises(ValueError):
37
+ self.loader.load_data(size=10)
38
+
39
+ @patch("src.prompt_loader.load_dataset")
40
+ def test_data_loading_on_demand(self, mock_load_dataset: MagicMock) -> None:
41
+ mock_load_dataset.return_value = self.mock_data
42
+ mock_load_dataset.assert_not_called()
43
+
44
+ self.loader.load_data()
45
+ mock_load_dataset.assert_called_once()
46
+
47
+ @patch("src.prompt_loader.load_dataset")
48
+ def test_random_sampling(self, mock_load_dataset: MagicMock) -> None:
49
+ mock_load_dataset.return_value = self.mock_data
50
+
51
+ self.loader.load_data()
52
+ sample = self.loader.load_data(size=2)
53
+
54
+ self.assertEqual(len(sample), 2)
55
+ self.assertTrue(set(sample).issubset({"prompt1", "prompt2", "prompt3"}))
56
+ self.assertNotEqual(sample, ["prompt1", "prompt2"])
57
+
58
+
59
+ if __name__ == "__main__":
60
+ unittest.main()
tests/test_similar_prompts.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ from unittest.mock import patch, Mock
3
+ import requests
4
+
5
+ # Assuming the function to be tested is in a module named `frontend.app_ui`
6
+ from frontend.app_ui import get_similar_prompts
7
+
8
+
9
+ class TestGetSimilarPrompts(unittest.TestCase):
10
+
11
+ @patch("frontend.app_ui.requests.post")
12
+ def test_get_similar_prompts_success(self, mock_post):
13
+ # Mock the response object to simulate a successful API call
14
+ mock_response = Mock()
15
+ mock_response.status_code = 200
16
+ mock_response.json.return_value = {"prompts": ["prompt1", "prompt2", "prompt3"]}
17
+ mock_post.return_value = mock_response
18
+
19
+ # Call the function with a sample query and number
20
+ result = get_similar_prompts("test query", 3)
21
+
22
+ # Assertions
23
+ self.assertIsInstance(result, dict)
24
+ self.assertEqual(result, {"prompts": ["prompt1", "prompt2", "prompt3"]})
25
+
26
+ @patch("frontend.app_ui.requests.post")
27
+ def test_get_similar_prompts_failure(self, mock_post):
28
+ # Mock the response object to simulate a failed API call
29
+ mock_post.side_effect = requests.RequestException("Mock request exception")
30
+
31
+ # Call the function with a sample query and number
32
+ result = get_similar_prompts("test query", 3)
33
+
34
+ # Assertions
35
+ self.assertIsNone(result)
36
+
37
+
38
+ if __name__ == "__main__":
39
+ unittest.main()