Spaces:
Sleeping
Sleeping
Lazar Radojevic
committed on
Commit
·
da82b2b
1
Parent(s):
1cd5053
final version
Browse files
- .gitignore +1 -0
- backend/routes.py +5 -0
- frontend/app_ui.py +5 -0
- poe/common-tasks.toml +1 -1
- poetry.lock +1 -1
- pyproject.toml +1 -0
- src/prompt_loader.py +3 -3
- tests/test_load_data.py +60 -0
- tests/test_similar_prompts.py +39 -0
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
backend/routes.py
CHANGED
|
@@ -7,6 +7,11 @@ from backend.models import QueryRequest, QueryResponse, SimilarPrompt
|
|
| 7 |
from src.prompt_loader import PromptLoader
|
| 8 |
from src.search_engine import PromptSearchEngine
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
# Constants
|
| 11 |
SEED = int(os.getenv("SEED", 42))
|
| 12 |
DATASET_SIZE = int(os.getenv("DATASET_SIZE", 1000))
|
|
|
|
| 7 |
from src.prompt_loader import PromptLoader
|
| 8 |
from src.search_engine import PromptSearchEngine
|
| 9 |
|
| 10 |
+
from dotenv import load_dotenv
|
| 11 |
+
|
| 12 |
+
# Load environment variables from .env file
|
| 13 |
+
load_dotenv()
|
| 14 |
+
|
| 15 |
# Constants
|
| 16 |
SEED = int(os.getenv("SEED", 42))
|
| 17 |
DATASET_SIZE = int(os.getenv("DATASET_SIZE", 1000))
|
frontend/app_ui.py
CHANGED
|
@@ -3,6 +3,11 @@ import os
|
|
| 3 |
import requests
|
| 4 |
import streamlit as st
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
# Read API URL from environment variable
|
| 7 |
API_URL = os.getenv("API_URL", "http://localhost:8000")
|
| 8 |
|
|
|
|
| 3 |
import requests
|
| 4 |
import streamlit as st
|
| 5 |
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
|
| 8 |
+
# Load environment variables from .env file
|
| 9 |
+
load_dotenv()
|
| 10 |
+
|
| 11 |
# Read API URL from environment variable
|
| 12 |
API_URL = os.getenv("API_URL", "http://localhost:8000")
|
| 13 |
|
poe/common-tasks.toml
CHANGED
|
@@ -34,7 +34,7 @@ cmd = "ruff check ."
|
|
| 34 |
|
| 35 |
[tool.poe.tasks.test]
|
| 36 |
help = "Run unit tests"
|
| 37 |
-
cmd = "
|
| 38 |
|
| 39 |
[tool.poe.tasks.clean]
|
| 40 |
help = "Remove automatically generated files"
|
|
|
|
| 34 |
|
| 35 |
[tool.poe.tasks.test]
|
| 36 |
help = "Run unit tests"
|
| 37 |
+
cmd = "python -m unittest discover -s tests"
|
| 38 |
|
| 39 |
[tool.poe.tasks.clean]
|
| 40 |
help = "Remove automatically generated files"
|
poetry.lock
CHANGED
|
@@ -3718,4 +3718,4 @@ multidict = ">=4.0"
|
|
| 3718 |
[metadata]
|
| 3719 |
lock-version = "2.0"
|
| 3720 |
python-versions = "^3.10"
|
| 3721 |
-
content-hash = "
|
|
|
|
| 3718 |
[metadata]
|
| 3719 |
lock-version = "2.0"
|
| 3720 |
python-versions = "^3.10"
|
| 3721 |
+
content-hash = "14c56c888e2fbf236863e1a06b7a2a42c79377dea1917f6d7387ed106713abfd"
|
pyproject.toml
CHANGED
|
@@ -15,6 +15,7 @@ numpy = "1.26.4"
|
|
| 15 |
fastapi = "^0.111.1"
|
| 16 |
uvicorn = "^0.30.3"
|
| 17 |
streamlit = "^1.37.0"
|
|
|
|
| 18 |
|
| 19 |
[tool.poetry.group.dev.dependencies]
|
| 20 |
black = "^24.1.1"
|
|
|
|
| 15 |
fastapi = "^0.111.1"
|
| 16 |
uvicorn = "^0.30.3"
|
| 17 |
streamlit = "^1.37.0"
|
| 18 |
+
python-dotenv = "^1.0.1"
|
| 19 |
|
| 20 |
[tool.poetry.group.dev.dependencies]
|
| 21 |
black = "^24.1.1"
|
src/prompt_loader.py
CHANGED
|
@@ -19,7 +19,7 @@ class PromptLoader:
|
|
| 19 |
self.randomizer = random.Random(seed)
|
| 20 |
self.data: Optional[List[str]] = None
|
| 21 |
|
| 22 |
-
def
|
| 23 |
"""
|
| 24 |
Loads the dataset of prompts and stores them in the `data` attribute.
|
| 25 |
|
|
@@ -33,7 +33,7 @@ class PromptLoader:
|
|
| 33 |
"""
|
| 34 |
Loads and samples prompts from the dataset.
|
| 35 |
|
| 36 |
-
If the dataset is not already loaded, it calls `
|
| 37 |
|
| 38 |
Args:
|
| 39 |
size (Optional[int]): The number of prompts to sample. If not specified, all loaded prompts are returned.
|
|
@@ -46,7 +46,7 @@ class PromptLoader:
|
|
| 46 |
ValueError: If `size` is specified and is greater than the number of available prompts.
|
| 47 |
"""
|
| 48 |
if not self.data:
|
| 49 |
-
self.
|
| 50 |
|
| 51 |
if size:
|
| 52 |
if size > len(self.data):
|
|
|
|
| 19 |
self.randomizer = random.Random(seed)
|
| 20 |
self.data: Optional[List[str]] = None
|
| 21 |
|
| 22 |
+
def _get_data(self) -> None:
|
| 23 |
"""
|
| 24 |
Loads the dataset of prompts and stores them in the `data` attribute.
|
| 25 |
|
|
|
|
| 33 |
"""
|
| 34 |
Loads and samples prompts from the dataset.
|
| 35 |
|
| 36 |
+
If the dataset is not already loaded, it calls `_get_data()` to load it.
|
| 37 |
|
| 38 |
Args:
|
| 39 |
size (Optional[int]): The number of prompts to sample. If not specified, all loaded prompts are returned.
|
|
|
|
| 46 |
ValueError: If `size` is specified and is greater than the number of available prompts.
|
| 47 |
"""
|
| 48 |
if not self.data:
|
| 49 |
+
self._get_data()
|
| 50 |
|
| 51 |
if size:
|
| 52 |
if size > len(self.data):
|
tests/test_load_data.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import unittest
|
| 2 |
+
from unittest.mock import patch, MagicMock
|
| 3 |
+
from src.prompt_loader import (
|
| 4 |
+
PromptLoader,
|
| 5 |
+
)
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class TestPromptLoader(unittest.TestCase):
|
| 9 |
+
|
| 10 |
+
def setUp(self) -> None:
|
| 11 |
+
# Set up a mock dataset for testing
|
| 12 |
+
self.mock_data = {"train": {"prompt": ["prompt1", "prompt2", "prompt3"]}}
|
| 13 |
+
self.loader = PromptLoader(seed=42)
|
| 14 |
+
|
| 15 |
+
@patch("src.prompt_loader.load_dataset")
|
| 16 |
+
def test_load_data_without_size(self, mock_load_dataset: MagicMock) -> None:
|
| 17 |
+
mock_load_dataset.return_value = self.mock_data
|
| 18 |
+
|
| 19 |
+
self.loader.load_data()
|
| 20 |
+
self.assertEqual(self.loader.data, ["prompt1", "prompt2", "prompt3"])
|
| 21 |
+
|
| 22 |
+
@patch("src.prompt_loader.load_dataset")
|
| 23 |
+
def test_load_data_with_size(self, mock_load_dataset: MagicMock) -> None:
|
| 24 |
+
mock_load_dataset.return_value = self.mock_data
|
| 25 |
+
self.loader.load_data()
|
| 26 |
+
sampled_data = self.loader.load_data(size=2)
|
| 27 |
+
|
| 28 |
+
self.assertEqual(len(sampled_data), 2)
|
| 29 |
+
self.assertTrue(set(sampled_data).issubset({"prompt1", "prompt2", "prompt3"}))
|
| 30 |
+
|
| 31 |
+
@patch("src.prompt_loader.load_dataset")
|
| 32 |
+
def test_load_data_size_exceeds(self, mock_load_dataset: MagicMock) -> None:
|
| 33 |
+
mock_load_dataset.return_value = self.mock_data
|
| 34 |
+
self.loader.load_data()
|
| 35 |
+
|
| 36 |
+
with self.assertRaises(ValueError):
|
| 37 |
+
self.loader.load_data(size=10)
|
| 38 |
+
|
| 39 |
+
@patch("src.prompt_loader.load_dataset")
|
| 40 |
+
def test_data_loading_on_demand(self, mock_load_dataset: MagicMock) -> None:
|
| 41 |
+
mock_load_dataset.return_value = self.mock_data
|
| 42 |
+
mock_load_dataset.assert_not_called()
|
| 43 |
+
|
| 44 |
+
self.loader.load_data()
|
| 45 |
+
mock_load_dataset.assert_called_once()
|
| 46 |
+
|
| 47 |
+
@patch("src.prompt_loader.load_dataset")
|
| 48 |
+
def test_random_sampling(self, mock_load_dataset: MagicMock) -> None:
|
| 49 |
+
mock_load_dataset.return_value = self.mock_data
|
| 50 |
+
|
| 51 |
+
self.loader.load_data()
|
| 52 |
+
sample = self.loader.load_data(size=2)
|
| 53 |
+
|
| 54 |
+
self.assertEqual(len(sample), 2)
|
| 55 |
+
self.assertTrue(set(sample).issubset({"prompt1", "prompt2", "prompt3"}))
|
| 56 |
+
self.assertNotEqual(sample, ["prompt1", "prompt2"])
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
if __name__ == "__main__":
|
| 60 |
+
unittest.main()
|
tests/test_similar_prompts.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import unittest
|
| 2 |
+
from unittest.mock import patch, Mock
|
| 3 |
+
import requests
|
| 4 |
+
|
| 5 |
+
# Assuming the function to be tested is in a module named `frontend.app_ui`
|
| 6 |
+
from frontend.app_ui import get_similar_prompts
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class TestGetSimilarPrompts(unittest.TestCase):
|
| 10 |
+
|
| 11 |
+
@patch("frontend.app_ui.requests.post")
|
| 12 |
+
def test_get_similar_prompts_success(self, mock_post):
|
| 13 |
+
# Mock the response object to simulate a successful API call
|
| 14 |
+
mock_response = Mock()
|
| 15 |
+
mock_response.status_code = 200
|
| 16 |
+
mock_response.json.return_value = {"prompts": ["prompt1", "prompt2", "prompt3"]}
|
| 17 |
+
mock_post.return_value = mock_response
|
| 18 |
+
|
| 19 |
+
# Call the function with a sample query and number
|
| 20 |
+
result = get_similar_prompts("test query", 3)
|
| 21 |
+
|
| 22 |
+
# Assertions
|
| 23 |
+
self.assertIsInstance(result, dict)
|
| 24 |
+
self.assertEqual(result, {"prompts": ["prompt1", "prompt2", "prompt3"]})
|
| 25 |
+
|
| 26 |
+
@patch("frontend.app_ui.requests.post")
|
| 27 |
+
def test_get_similar_prompts_failure(self, mock_post):
|
| 28 |
+
# Mock the response object to simulate a failed API call
|
| 29 |
+
mock_post.side_effect = requests.RequestException("Mock request exception")
|
| 30 |
+
|
| 31 |
+
# Call the function with a sample query and number
|
| 32 |
+
result = get_similar_prompts("test query", 3)
|
| 33 |
+
|
| 34 |
+
# Assertions
|
| 35 |
+
self.assertIsNone(result)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
if __name__ == "__main__":
|
| 39 |
+
unittest.main()
|