Spaces:
Sleeping
Sleeping
File size: 3,064 Bytes
1cd5053 3556e6f 0b497e7 1cd5053 45462fb 3556e6f 1cd5053 7af929b 3556e6f 7af929b da82b2b 7af929b 1cd5053 7af929b 1cd5053 7af929b 1cd5053 7af929b 1cd5053 7af929b 1cd5053 7af929b 1cd5053 7af929b 1cd5053 7af929b be043a6 45462fb 1cd5053 45462fb 0b497e7 45462fb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import logging
import os

from dotenv import load_dotenv
from fastapi import APIRouter, HTTPException
from fastapi.responses import HTMLResponse

from backend.models import QueryRequest, QueryResponse, SimilarPrompt
from src.prompt_loader import PromptLoader
from src.search_engine import PromptSearchEngine
# Read settings from a .env file into the process environment (python-dotenv)
# before any of the environment lookups below run.
load_dotenv()

# Tunable constants; each can be overridden through the environment.
SEED = int(os.getenv("SEED", "42"))  # seed handed to PromptLoader
DATASET_SIZE = int(os.getenv("DATASET_SIZE", "1000"))  # number of prompts to load

# Load the prompt collection once, at import time, and build the search
# engine over it; every request below shares this single engine instance.
prompts = PromptLoader(seed=SEED).load_data(size=DATASET_SIZE)
engine = PromptSearchEngine(prompts)

# Router on which the endpoints defined in this module are registered.
router = APIRouter()
@router.post("/most_similar", response_model=QueryResponse)
async def get_most_similar(query_request: QueryRequest) -> QueryResponse:
    """
    Endpoint to retrieve the most similar prompts based on a user query.

    Args:
        query_request (QueryRequest): The request payload containing the user
            query and the number of similar prompts to retrieve.

    Returns:
        QueryResponse: A response containing a list of similar prompts and
            their similarity scores.

    Raises:
        HTTPException: With status 500 if searching or building the response
            fails. The underlying error is logged server-side (with its
            traceback) instead of being echoed back to the client.
    """
    # NOTE(review): engine.most_similar is called synchronously from an async
    # endpoint, so it runs on the event-loop thread and blocks other requests
    # while it executes. If it is slow, consider
    # fastapi.concurrency.run_in_threadpool.
    try:
        similar_prompts = engine.most_similar(
            query=query_request.query, n=query_request.n
        )
        # engine.most_similar is consumed as an iterable of (score, prompt) pairs.
        response = QueryResponse(
            similar_prompts=[
                SimilarPrompt(score=score, prompt=prompt)
                for score, prompt in similar_prompts
            ]
        )
    except Exception as e:
        # Keep full diagnostics in the server logs, but don't leak internal
        # exception text (paths, library internals) to untrusted API clients.
        logging.getLogger(__name__).exception(
            "Failed to compute most similar prompts"
        )
        # Chain the original error so the cause is preserved in tracebacks.
        raise HTTPException(
            status_code=500, detail="Internal server error"
        ) from e
    return response
@router.get("/", response_class=HTMLResponse)
async def home_page() -> HTMLResponse:
    """Serve the static landing page for the API.

    Returns:
        HTMLResponse: A fixed HTML document that documents the
            ``POST /most_similar`` endpoint, including an example request
            body and an example response body.
    """
    # The markup is static, so it is held in a plain local and handed to the
    # response unchanged.
    page = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Prompt Search Engine</title>
<style>
body { font-family: Arial, sans-serif; margin: 20px; }
h1 { color: #333; }
p { margin-bottom: 10px; }
code { background: #f4f4f4; padding: 2px 4px; border-radius: 4px; }
.container { max-width: 800px; margin: 0 auto; }
</style>
</head>
<body>
<div class="container">
<h1>Prompt Search Engine API</h1>
<p>Use this API to find similar prompts based on a query.</p>
<h2>POST /most_similar</h2>
<p><strong>Request:</strong> <code>{"query": "string", "n": 1}</code></p>
<p><strong>Response:</strong> <code>{"similar_prompts": [{"score": 0.95, "prompt": "Example prompt 1"}]}</code></p>
</div>
</body>
</html>
"""
    return HTMLResponse(content=page)
|