Vitomir Jovanović committed on
Commit
01f5415
1 Parent(s): e6bc9b1

Search Engine

api.py ADDED
@@ -0,0 +1,33 @@
+ # streamlit_app.py
+ import streamlit as st
+ import requests
+
+ # Streamlit app title
+ st.title("Top K Search with Vector Database")
+
+ # FastAPI endpoint URL
+ url = "http://127.0.0.1:8000/search/"
+
+ # Input fields in Streamlit
+ query_id = st.text_input("Enter ID:", value="1")
+ prompt = st.text_input("Enter your prompt:")
+ k = st.number_input("Top K results:", min_value=1, max_value=100, value=3)
+
+ # Trigger the search when the button is clicked
+ if st.button("Search"):
+     # Construct the request payload; k goes in the query string,
+     # since the FastAPI endpoint declares it as a query parameter
+     payload = {
+         "id": query_id,
+         "prompt": prompt
+     }
+
+     # Make the POST request
+     response = requests.post(url, json=payload, params={"k": int(k)})
+
+     # Handle the response
+     if response.status_code == 200:
+         results = response.json()
+         st.write(results)
+     else:
+         st.error(f"Error: {response.status_code} - {response.text}")
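For reference, the same round trip the Search button performs can be exercised without the UI. A minimal sketch, assuming the FastAPI service from main.py is already listening on 127.0.0.1:8000 (the prompt text is illustrative):

    import requests

    payload = {"id": "1", "prompt": "a cat in the snow"}
    response = requests.post("http://127.0.0.1:8000/search/", json=payload, params={"k": 3})
    response.raise_for_status()
    for hit in response.json()["results"]:
        print(hit["prompt"], hit["distance"])

Note that k is a query parameter on the FastAPI side, which is why both the Streamlit app and this sketch pass it via params rather than in the JSON body.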
environment.yaml ADDED
Binary file (4.77 kB).
 
main.py ADDED
@@ -0,0 +1,53 @@
+ import socket
+ import uvicorn
+ from fastapi import FastAPI
+ from models.prompt_search_engine import PromptSearchEngine
+ from models.data_reader import load_prompts_from_jsonl
+ from models.Query import Query, SearchResponse, SimilarPrompt
+
+
+ prompt_path = r"C:\Users\jov2bg\Desktop\PromptSearch\models\prompts_data.jsonl"
+
+ app = FastAPI(title="Search Prompt Engine", description="API for prompt search", version="1.0")
+
+ prompts = load_prompts_from_jsonl(prompt_path)
+ search_engine = PromptSearchEngine()
+ search_engine.add_prompts_to_vector_database(prompts)
+
+
+ @app.get("/")
+ def read_root():
+     return {"message": "Prompt Search Engine is running!"}
+
+
+ @app.post("/search/")
+ async def search_prompts(query: Query, k: int = 3):
+     print(f"Prompt: {query.prompt}")
+     similar_prompts, distances = search_engine.most_similar(query.prompt, top_k=k)
+     print(f"Similar prompts: {similar_prompts}")
+     print(f"Distances: {distances}")
+     # Format the response
+     response = [
+         SimilarPrompt(prompt=prompt, distance=float(distance))
+         for prompt, distance in zip(similar_prompts, distances)
+     ]
+     return SearchResponse(results=response)
+
+
+ @app.post("/all_vectors_similarities/")
+ async def all_vectors(query: Query):
+     # Encode the query, then compare it against every stored vector;
+     # IndexFlatL2 keeps the raw vectors, so they can be reconstructed in bulk
+     query_vector = search_engine.model.encode([query.prompt])[0]
+     corpus_vectors = search_engine.index.reconstruct_n(0, search_engine.index.ntotal)
+     all_similarities = search_engine.cosine_similarity(query_vector, corpus_vectors)
+     response = [
+         SimilarPrompt(prompt=search_engine.prompts_track[idx], distance=float(similarity))
+         for idx, similarity in all_similarities.items()
+     ]
+     return SearchResponse(results=response)
+
+
+ if __name__ == "__main__":
+     localhost = socket.gethostbyname("localhost")
+     uvicorn.run(app, host=localhost, port=8000)
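The endpoints can also be smoke-tested in-process with FastAPI's TestClient, without a running server. A sketch under the assumption that the prompts file exists, since importing main builds the index at module load:

    from fastapi.testclient import TestClient
    from main import app

    client = TestClient(app)
    assert client.get("/").json() == {"message": "Prompt Search Engine is running!"}
    resp = client.post("/search/", params={"k": 2}, json={"id": "1", "prompt": "a mountain lake at dawn"})
    assert resp.status_code == 200
    print(resp.json()["results"])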
models/Query.py ADDED
@@ -0,0 +1,20 @@
+ from pydantic import BaseModel
+ from typing import List
+
+
+ class Query(BaseModel):
+     id: str
+     prompt: str
+
+
+ class Query_Multiple(BaseModel):
+     prompt: List[Query]
+
+
+ class SimilarPrompt(BaseModel):
+     prompt: str
+     distance: float
+
+
+ class SearchResponse(BaseModel):
+     results: List[SimilarPrompt]
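These models pin down the request and response contracts. A small sketch of how they validate and serialize (values illustrative):

    from models.Query import Query, SearchResponse, SimilarPrompt

    q = Query(id="1", prompt="a red bicycle")
    resp = SearchResponse(results=[SimilarPrompt(prompt="a red bike", distance=0.31)])
    print(resp.dict())  # .model_dump() on pydantic v2
    # {'results': [{'prompt': 'a red bike', 'distance': 0.31}]}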
models/__pycache__/Query.cpython-312.pyc ADDED
Binary file (1.11 kB).
 
models/__pycache__/data_reader.cpython-312.pyc ADDED
Binary file (2.35 kB).
 
models/__pycache__/prompt_search_engine.cpython-312.pyc ADDED
Binary file (3.53 kB).
 
models/__pycache__/vectorizer.cpython-312.pyc ADDED
Binary file (1.82 kB).
 
models/data_reader.py ADDED
@@ -0,0 +1,48 @@
+ import json
+ from datasets import load_dataset
+
+
+ # Webdataset shards of the text-to-image-2M dataset
+ base_url = "https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_512_2M/data_{i:06d}.tar"
+ num_shards = 46  # Number of webdataset tar files
+
+
+ def download_data(base_url, num_shards):
+     # Stream the webdataset shards instead of downloading them up front
+     urls = [base_url.format(i=i) for i in range(num_shards)]
+     dataset = load_dataset("webdataset", data_files={"train": urls}, split="train", streaming=True)
+     return dataset
+
+
+ def extract_prompts(dataset, jsonl_file_path):
+     # Write one JSON-encoded prompt per line to the jsonl file
+     prompts = {}
+     with open(jsonl_file_path, 'w') as f:
+         for index, row in enumerate(dataset):
+             prompts[index] = row['json']['prompt']
+             f.write(json.dumps(prompts[index]) + '\n')
+
+
+ def read_data(jsonl_file_path):
+     # Read data from the jsonl file and print each row
+     with open(jsonl_file_path, 'r') as f:
+         for line in f:
+             row = json.loads(line)
+             print(row)
+
+
+ def load_prompts_from_jsonl(file_path):
+     prompts = []
+     with open(file_path, 'r') as f:
+         for line in f:
+             data = json.loads(line)  # Each line is a JSON-encoded prompt string
+             prompts.append(data)
+     return prompts
+
+
+ if __name__ == "__main__":
+     jsonl_file_path = r"C:\Users\jov2bg\Desktop\PromptSearch\models\prompts_data.jsonl"
+     num_shards = 1
+     dataset = download_data(base_url, num_shards)
+     extract_prompts(dataset, jsonl_file_path)
+     read_data(jsonl_file_path)
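Since extract_prompts writes json.dumps(prompt) per row, each line of prompts_data.jsonl is a single JSON-encoded string rather than an object, and load_prompts_from_jsonl therefore yields plain strings. Illustrative file contents:

    "A watercolor painting of a lighthouse at dusk"
    "Portrait photo of an astronaut, studio lighting"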
models/prompt_search_engine.py ADDED
@@ -0,0 +1,48 @@
+ from typing import Dict
+ import numpy as np
+ import faiss
+ from sentence_transformers import SentenceTransformer
+
+
+ class PromptSearchEngine:
+     def __init__(self, model_name='bert-base-nli-mean-tokens'):
+         self.model = SentenceTransformer(model_name)
+         # Initialize a FAISS index with the right number of dimensions
+         self.embedding_dimension = self.model.get_sentence_embedding_dimension()
+         self.index = faiss.IndexFlatL2(self.embedding_dimension)  # Euclidean distance index - brute force, fine for small datasets
+         self.prompts_track = []  # Keeps the original prompts so results can be returned as text
+
+     def add_prompts_to_vector_database(self, prompts):
+         embeddings = self.model.encode(prompts)
+         self.index.add(np.array(embeddings).astype('float32'))
+         self.prompts_track.extend(prompts)
+
+     def most_similar(self, query, top_k=5):
+         # Encode the query
+         query_embedding = self.model.encode([query]).astype('float32')
+
+         # Optimized search, but the index type would have to change for large corpora
+         distances, indices = self.index.search(query_embedding, top_k)
+
+         # Retrieve the corresponding prompts for the found indices
+         similar_prompts = [self.prompts_track[idx] for idx in indices[0]]
+
+         return similar_prompts, distances[0]  # Both the similar prompts and their distances
+
+     def cosine_similarity(self, query_vector: np.ndarray, corpus_vectors: np.ndarray) -> Dict[int, float]:
+         """Compute the cosine similarity between a query vector and a set of corpus vectors.
+
+         Args:
+             query_vector: The query vector to compare against the corpus vectors.
+             corpus_vectors: The set of corpus vectors to compare against the query vector.
+
+         Returns:
+             A mapping from corpus index to the cosine similarity with the query vector.
+         """
+         similarities = {}
+         for index, vector in enumerate(corpus_vectors):
+             if np.linalg.norm(vector) == 0:
+                 raise ValueError("One of the corpus vectors has zero norm.")
+             cos_similarity = np.dot(vector, query_vector) / (np.linalg.norm(vector) * np.linalg.norm(query_vector))
+             similarities[index] = cos_similarity
+         return similarities
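Standalone, the engine can be exercised with a handful of prompts. A minimal sketch (the default model is downloaded on first use, and IndexFlatL2 reports squared L2 distances, so smaller means closer):

    from models.prompt_search_engine import PromptSearchEngine

    engine = PromptSearchEngine()
    engine.add_prompts_to_vector_database([
        "a dog playing in the park",
        "a portrait of an old sailor",
        "a dog catching a frisbee",
    ])
    prompts, distances = engine.most_similar("dog outdoors", top_k=2)
    print(prompts, distances)  # expect the two dog prompts first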
models/prompts_data.jsonl ADDED
The diff for this file is too large to render.
 
models/vectorizer.py ADDED
@@ -0,0 +1,33 @@
+ from typing import Sequence
+
+ import numpy as np
+ import faiss
+
+
+ class Vectorizer:
+     def __init__(self, model) -> None:
+         """Initialize the vectorizer with a pre-trained embedding model.
+
+         Args:
+             model: The pre-trained embedding model to use for transforming prompts.
+         """
+         self.model = model
+         # The index dimension must match the embedding dimension (not the corpus size);
+         # assumes a SentenceTransformer-style model exposing get_sentence_embedding_dimension()
+         self.embedding_dimension = self.model.get_sentence_embedding_dimension()
+         self.index = faiss.IndexFlatIP(self.embedding_dimension)  # Inner-product index
+         self.cached_index_idx_to_retrieval_db_idx = []
+
+     def transform_and_add_to_index(self, prompts: Sequence[str]) -> np.ndarray:
+         """Transform texts into numerical vectors using the specified model.
+
+         Args:
+             prompts: The sequence of raw corpus prompts.
+
+         Returns:
+             The vectorized prompts.
+         """
+         embeddings = self.model.encode(prompts)
+         print('Embedding dimension:', embeddings.shape[1])
+         self.index.add(np.array(embeddings).astype('float32'))
+         return embeddings
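A short usage sketch. IndexFlatIP ranks by raw inner product, so embeddings should be L2-normalized first if cosine ranking is intended; the model name below is the same default PromptSearchEngine uses:

    from sentence_transformers import SentenceTransformer
    from models.vectorizer import Vectorizer

    model = SentenceTransformer('bert-base-nli-mean-tokens')
    vectorizer = Vectorizer(model)
    vectorizer.transform_and_add_to_index(["a city skyline at night", "a bowl of fresh fruit"])
    print(vectorizer.index.ntotal)  # 2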
requirements.txt ADDED
Binary file (5.06 kB).