Spaces:
Sleeping
Sleeping
Lazar Radojevic
commited on
Commit
·
da82b2b
1
Parent(s):
1cd5053
final version
Browse files- .gitignore +1 -0
- backend/routes.py +5 -0
- frontend/app_ui.py +5 -0
- poe/common-tasks.toml +1 -1
- poetry.lock +1 -1
- pyproject.toml +1 -0
- src/prompt_loader.py +3 -3
- tests/test_load_data.py +60 -0
- tests/test_similar_prompts.py +39 -0
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__pycache__
|
backend/routes.py
CHANGED
@@ -7,6 +7,11 @@ from backend.models import QueryRequest, QueryResponse, SimilarPrompt
|
|
7 |
from src.prompt_loader import PromptLoader
|
8 |
from src.search_engine import PromptSearchEngine
|
9 |
|
|
|
|
|
|
|
|
|
|
|
10 |
# Constants
|
11 |
SEED = int(os.getenv("SEED", 42))
|
12 |
DATASET_SIZE = int(os.getenv("DATASET_SIZE", 1000))
|
|
|
7 |
from src.prompt_loader import PromptLoader
|
8 |
from src.search_engine import PromptSearchEngine
|
9 |
|
10 |
+
from dotenv import load_dotenv
|
11 |
+
|
12 |
+
# Load environment variables from .env file
|
13 |
+
load_dotenv()
|
14 |
+
|
15 |
# Constants
|
16 |
SEED = int(os.getenv("SEED", 42))
|
17 |
DATASET_SIZE = int(os.getenv("DATASET_SIZE", 1000))
|
frontend/app_ui.py
CHANGED
@@ -3,6 +3,11 @@ import os
|
|
3 |
import requests
|
4 |
import streamlit as st
|
5 |
|
|
|
|
|
|
|
|
|
|
|
6 |
# Read API URL from environment variable
|
7 |
API_URL = os.getenv("API_URL", "http://localhost:8000")
|
8 |
|
|
|
3 |
import requests
|
4 |
import streamlit as st
|
5 |
|
6 |
+
from dotenv import load_dotenv
|
7 |
+
|
8 |
+
# Load environment variables from .env file
|
9 |
+
load_dotenv()
|
10 |
+
|
11 |
# Read API URL from environment variable
|
12 |
API_URL = os.getenv("API_URL", "http://localhost:8000")
|
13 |
|
poe/common-tasks.toml
CHANGED
@@ -34,7 +34,7 @@ cmd = "ruff check ."
|
|
34 |
|
35 |
[tool.poe.tasks.test]
|
36 |
help = "Run unit tests"
|
37 |
-
cmd = "
|
38 |
|
39 |
[tool.poe.tasks.clean]
|
40 |
help = "Remove automatically generated files"
|
|
|
34 |
|
35 |
[tool.poe.tasks.test]
|
36 |
help = "Run unit tests"
|
37 |
+
cmd = "python -m unittest discover -s tests"
|
38 |
|
39 |
[tool.poe.tasks.clean]
|
40 |
help = "Remove automatically generated files"
|
poetry.lock
CHANGED
@@ -3718,4 +3718,4 @@ multidict = ">=4.0"
|
|
3718 |
[metadata]
|
3719 |
lock-version = "2.0"
|
3720 |
python-versions = "^3.10"
|
3721 |
-
content-hash = "
|
|
|
3718 |
[metadata]
|
3719 |
lock-version = "2.0"
|
3720 |
python-versions = "^3.10"
|
3721 |
+
content-hash = "14c56c888e2fbf236863e1a06b7a2a42c79377dea1917f6d7387ed106713abfd"
|
pyproject.toml
CHANGED
@@ -15,6 +15,7 @@ numpy = "1.26.4"
|
|
15 |
fastapi = "^0.111.1"
|
16 |
uvicorn = "^0.30.3"
|
17 |
streamlit = "^1.37.0"
|
|
|
18 |
|
19 |
[tool.poetry.group.dev.dependencies]
|
20 |
black = "^24.1.1"
|
|
|
15 |
fastapi = "^0.111.1"
|
16 |
uvicorn = "^0.30.3"
|
17 |
streamlit = "^1.37.0"
|
18 |
+
python-dotenv = "^1.0.1"
|
19 |
|
20 |
[tool.poetry.group.dev.dependencies]
|
21 |
black = "^24.1.1"
|
src/prompt_loader.py
CHANGED
@@ -19,7 +19,7 @@ class PromptLoader:
|
|
19 |
self.randomizer = random.Random(seed)
|
20 |
self.data: Optional[List[str]] = None
|
21 |
|
22 |
-
def
|
23 |
"""
|
24 |
Loads the dataset of prompts and stores them in the `data` attribute.
|
25 |
|
@@ -33,7 +33,7 @@ class PromptLoader:
|
|
33 |
"""
|
34 |
Loads and samples prompts from the dataset.
|
35 |
|
36 |
-
If the dataset is not already loaded, it calls `
|
37 |
|
38 |
Args:
|
39 |
size (Optional[int]): The number of prompts to sample. If not specified, all loaded prompts are returned.
|
@@ -46,7 +46,7 @@ class PromptLoader:
|
|
46 |
ValueError: If `size` is specified and is greater than the number of available prompts.
|
47 |
"""
|
48 |
if not self.data:
|
49 |
-
self.
|
50 |
|
51 |
if size:
|
52 |
if size > len(self.data):
|
|
|
19 |
self.randomizer = random.Random(seed)
|
20 |
self.data: Optional[List[str]] = None
|
21 |
|
22 |
+
def _get_data(self) -> None:
|
23 |
"""
|
24 |
Loads the dataset of prompts and stores them in the `data` attribute.
|
25 |
|
|
|
33 |
"""
|
34 |
Loads and samples prompts from the dataset.
|
35 |
|
36 |
+
If the dataset is not already loaded, it calls `_get_data()` to load it.
|
37 |
|
38 |
Args:
|
39 |
size (Optional[int]): The number of prompts to sample. If not specified, all loaded prompts are returned.
|
|
|
46 |
ValueError: If `size` is specified and is greater than the number of available prompts.
|
47 |
"""
|
48 |
if not self.data:
|
49 |
+
self._get_data()
|
50 |
|
51 |
if size:
|
52 |
if size > len(self.data):
|
tests/test_load_data.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import unittest
|
2 |
+
from unittest.mock import patch, MagicMock
|
3 |
+
from src.prompt_loader import (
|
4 |
+
PromptLoader,
|
5 |
+
)
|
6 |
+
|
7 |
+
|
8 |
+
class TestPromptLoader(unittest.TestCase):
|
9 |
+
|
10 |
+
def setUp(self) -> None:
|
11 |
+
# Set up a mock dataset for testing
|
12 |
+
self.mock_data = {"train": {"prompt": ["prompt1", "prompt2", "prompt3"]}}
|
13 |
+
self.loader = PromptLoader(seed=42)
|
14 |
+
|
15 |
+
@patch("src.prompt_loader.load_dataset")
|
16 |
+
def test_load_data_without_size(self, mock_load_dataset: MagicMock) -> None:
|
17 |
+
mock_load_dataset.return_value = self.mock_data
|
18 |
+
|
19 |
+
self.loader.load_data()
|
20 |
+
self.assertEqual(self.loader.data, ["prompt1", "prompt2", "prompt3"])
|
21 |
+
|
22 |
+
@patch("src.prompt_loader.load_dataset")
|
23 |
+
def test_load_data_with_size(self, mock_load_dataset: MagicMock) -> None:
|
24 |
+
mock_load_dataset.return_value = self.mock_data
|
25 |
+
self.loader.load_data()
|
26 |
+
sampled_data = self.loader.load_data(size=2)
|
27 |
+
|
28 |
+
self.assertEqual(len(sampled_data), 2)
|
29 |
+
self.assertTrue(set(sampled_data).issubset({"prompt1", "prompt2", "prompt3"}))
|
30 |
+
|
31 |
+
@patch("src.prompt_loader.load_dataset")
|
32 |
+
def test_load_data_size_exceeds(self, mock_load_dataset: MagicMock) -> None:
|
33 |
+
mock_load_dataset.return_value = self.mock_data
|
34 |
+
self.loader.load_data()
|
35 |
+
|
36 |
+
with self.assertRaises(ValueError):
|
37 |
+
self.loader.load_data(size=10)
|
38 |
+
|
39 |
+
@patch("src.prompt_loader.load_dataset")
|
40 |
+
def test_data_loading_on_demand(self, mock_load_dataset: MagicMock) -> None:
|
41 |
+
mock_load_dataset.return_value = self.mock_data
|
42 |
+
mock_load_dataset.assert_not_called()
|
43 |
+
|
44 |
+
self.loader.load_data()
|
45 |
+
mock_load_dataset.assert_called_once()
|
46 |
+
|
47 |
+
@patch("src.prompt_loader.load_dataset")
|
48 |
+
def test_random_sampling(self, mock_load_dataset: MagicMock) -> None:
|
49 |
+
mock_load_dataset.return_value = self.mock_data
|
50 |
+
|
51 |
+
self.loader.load_data()
|
52 |
+
sample = self.loader.load_data(size=2)
|
53 |
+
|
54 |
+
self.assertEqual(len(sample), 2)
|
55 |
+
self.assertTrue(set(sample).issubset({"prompt1", "prompt2", "prompt3"}))
|
56 |
+
self.assertNotEqual(sample, ["prompt1", "prompt2"])
|
57 |
+
|
58 |
+
|
59 |
+
if __name__ == "__main__":
|
60 |
+
unittest.main()
|
tests/test_similar_prompts.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import unittest
|
2 |
+
from unittest.mock import patch, Mock
|
3 |
+
import requests
|
4 |
+
|
5 |
+
# Assuming the function to be tested is in a module named `frontend.app_ui`
|
6 |
+
from frontend.app_ui import get_similar_prompts
|
7 |
+
|
8 |
+
|
9 |
+
class TestGetSimilarPrompts(unittest.TestCase):
|
10 |
+
|
11 |
+
@patch("frontend.app_ui.requests.post")
|
12 |
+
def test_get_similar_prompts_success(self, mock_post):
|
13 |
+
# Mock the response object to simulate a successful API call
|
14 |
+
mock_response = Mock()
|
15 |
+
mock_response.status_code = 200
|
16 |
+
mock_response.json.return_value = {"prompts": ["prompt1", "prompt2", "prompt3"]}
|
17 |
+
mock_post.return_value = mock_response
|
18 |
+
|
19 |
+
# Call the function with a sample query and number
|
20 |
+
result = get_similar_prompts("test query", 3)
|
21 |
+
|
22 |
+
# Assertions
|
23 |
+
self.assertIsInstance(result, dict)
|
24 |
+
self.assertEqual(result, {"prompts": ["prompt1", "prompt2", "prompt3"]})
|
25 |
+
|
26 |
+
@patch("frontend.app_ui.requests.post")
|
27 |
+
def test_get_similar_prompts_failure(self, mock_post):
|
28 |
+
# Mock the response object to simulate a failed API call
|
29 |
+
mock_post.side_effect = requests.RequestException("Mock request exception")
|
30 |
+
|
31 |
+
# Call the function with a sample query and number
|
32 |
+
result = get_similar_prompts("test query", 3)
|
33 |
+
|
34 |
+
# Assertions
|
35 |
+
self.assertIsNone(result)
|
36 |
+
|
37 |
+
|
38 |
+
if __name__ == "__main__":
|
39 |
+
unittest.main()
|