File size: 3,064 Bytes
1cd5053
3556e6f
0b497e7
1cd5053
45462fb
3556e6f
1cd5053
7af929b
3556e6f
7af929b
da82b2b
 
 
7af929b
1cd5053
 
7af929b
 
1cd5053
7af929b
 
1cd5053
 
7af929b
 
1cd5053
 
 
 
7af929b
1cd5053
 
7af929b
1cd5053
 
7af929b
1cd5053
 
 
7af929b
 
 
 
 
 
 
 
 
 
 
 
be043a6
45462fb
 
1cd5053
 
 
 
 
 
 
 
45462fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0b497e7
45462fb
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import logging
import os

from dotenv import load_dotenv
from fastapi import APIRouter, HTTPException
from fastapi.responses import HTMLResponse

from backend.models import QueryRequest, QueryResponse, SimilarPrompt
from src.prompt_loader import PromptLoader
from src.search_engine import PromptSearchEngine

# Read key=value pairs from a .env file (when one exists) into the process
# environment so the settings below can be overridden without code changes.
load_dotenv()

# Constants: each falls back to a fixed default when the variable is unset.
SEED = int(os.environ.get("SEED", "42"))
DATASET_SIZE = int(os.environ.get("DATASET_SIZE", "1000"))

# Router that the endpoints defined below are registered on.
router = APIRouter()

# Built once at import time and shared by every request served by this module.
prompts = PromptLoader(seed=SEED).load_data(size=DATASET_SIZE)
engine = PromptSearchEngine(prompts)


# Module logger: records the full traceback of failed searches server-side,
# since clients only receive a generic error message.
_logger = logging.getLogger(__name__)


@router.post("/most_similar", response_model=QueryResponse)
async def get_most_similar(query_request: QueryRequest) -> QueryResponse:
    """
    Endpoint to retrieve the most similar prompts based on a user query.

    Args:
    query_request (QueryRequest): The request payload containing the user query
        and the number of similar prompts to retrieve.

    Returns:
    QueryResponse: A response containing a list of similar prompts and their similarity scores.

    Raises:
    HTTPException: With status 500 if the search or response construction fails.
        The underlying error is logged with its traceback and chained as the
        cause. The client receives only a generic message, so internal details
        (paths, model errors, etc.) are not disclosed.
    """
    try:
        # NOTE(review): engine.most_similar is a synchronous call inside an
        # async handler, so it runs on the event-loop thread and blocks other
        # requests while it executes. If it is CPU-heavy, consider a plain
        # `def` handler or fastapi.concurrency.run_in_threadpool. Confirm first
        # that the shared engine is safe to call from worker threads.
        similar_prompts = engine.most_similar(
            query=query_request.query, n=query_request.n
        )
        # Each engine result is unpacked as a (score, prompt) pair.
        return QueryResponse(
            similar_prompts=[
                SimilarPrompt(score=score, prompt=prompt)
                for score, prompt in similar_prompts
            ]
        )
    except Exception as exc:
        # Top-level request boundary: log the real failure, then surface a
        # generic 500 while preserving the cause via exception chaining.
        _logger.exception("Failed to compute most similar prompts")
        raise HTTPException(
            status_code=500, detail="Internal server error"
        ) from exc


@router.get("/", response_class=HTMLResponse)
async def home_page() -> HTMLResponse:
    """
    Serve a static HTML overview of the API.

    Returns:
    HTMLResponse: A page describing the POST /most_similar endpoint together
    with example request and response payloads.
    """
    # Static markup; kept verbatim so the served bytes never change.
    page = """
        <!DOCTYPE html>
        <html lang="en">
        <head>
            <meta charset="UTF-8">
            <meta name="viewport" content="width=device-width, initial-scale=1.0">
            <title>Prompt Search Engine</title>
            <style>
                body { font-family: Arial, sans-serif; margin: 20px; }
                h1 { color: #333; }
                p { margin-bottom: 10px; }
                code { background: #f4f4f4; padding: 2px 4px; border-radius: 4px; }
                .container { max-width: 800px; margin: 0 auto; }
            </style>
        </head>
        <body>
            <div class="container">
                <h1>Prompt Search Engine API</h1>
                <p>Use this API to find similar prompts based on a query.</p>
                <h2>POST /most_similar</h2>
                <p><strong>Request:</strong> <code>{"query": "string", "n": 1}</code></p>
                <p><strong>Response:</strong> <code>{"similar_prompts": [{"score": 0.95, "prompt": "Example prompt 1"}]}</code></p>
            </div>
        </body>
        </html>
        """
    return HTMLResponse(content=page)