Spaces:
Runtime error
Runtime error
Vitomir Jovanović
committed on
Commit
•
01f5415
1
Parent(s):
e6bc9b1
Search Engine
Browse files- api.py +33 -0
- environment.yaml +0 -0
- main.py +58 -0
- models/Query.py +20 -0
- models/__pycache__/Query.cpython-312.pyc +0 -0
- models/__pycache__/data_reader.cpython-312.pyc +0 -0
- models/__pycache__/prompt_search_engine.cpython-312.pyc +0 -0
- models/__pycache__/vectorizer.cpython-312.pyc +0 -0
- models/data_reader.py +48 -0
- models/prompt_search_engine.py +48 -0
- models/prompts_data.jsonl +0 -0
- models/vectorizer.py +33 -0
- requirements.txt +0 -0
api.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# streamlit_app.py
|
2 |
+
import streamlit as st
|
3 |
+
import requests
|
4 |
+
|
5 |
+
# Streamlit front end for the prompt search API.
st.title("Top K Search with Vector DataBase")

# FastAPI endpoint URL
url = "http://127.0.0.1:8000/search/"

# Input fields in Streamlit (named query_id so the builtin `id` is not shadowed)
query_id = st.text_input("Enter ID:", value="1")
prompt = st.text_input("Enter your prompt:")
k = st.number_input("Top K results:", min_value=1, max_value=100, value=3)

# Trigger the search when the button is clicked
if st.button("Search"):
    if not prompt.strip():
        # Nothing meaningful to search for; avoid a pointless round trip.
        st.warning("Please enter a prompt before searching.")
    else:
        # Request body. `k` stays in the body for backward compatibility, but
        # the /search/ endpoint declares `k` as a plain function argument, so
        # FastAPI reads it from the query string -- it is sent there as well.
        payload = {
            "id": query_id,
            "prompt": prompt,
            "k": k,
        }

        try:
            # A timeout keeps the UI from hanging forever if the API is down.
            response = requests.post(
                url, json=payload, params={"k": int(k)}, timeout=30
            )
        except requests.RequestException as exc:
            st.error(f"Error: could not reach the search API - {exc}")
        else:
            # Handle the response
            if response.status_code == 200:
                st.write(response.json())
            else:
                st.error(f"Error: {response.status_code} - {response.text}")
|
environment.yaml
ADDED
Binary file (4.77 kB). View file
|
|
main.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
import copy
|
4 |
+
import uvicorn
|
5 |
+
import socket
|
6 |
+
import logging
|
7 |
+
import datetime
|
8 |
+
from models.vectorizer import Vectorizer
|
9 |
+
from models.prompt_search_engine import PromptSearchEngine
|
10 |
+
from models.data_reader import load_prompts_from_jsonl
|
11 |
+
from models.Query import Query, Query_Multiple, SearchResponse, SimilarPrompt
|
12 |
+
from decouple import config
|
13 |
+
from fastapi import FastAPI, HTTPException, Depends, Body
|
14 |
+
from sentence_transformers import SentenceTransformer
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
# Locate the prompts file relative to this module instead of a machine-specific
# absolute path, so the service starts on any host that has the repository.
prompt_path = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "models", "prompts_data.jsonl"
)

app = FastAPI(title="Search Prompt Engine", description="API for prompt search", version="1.0")

# Build the in-memory vector index once at import time; every request reuses it.
prompts = load_prompts_from_jsonl(prompt_path)
search_engine = PromptSearchEngine()
search_engine.add_prompts_to_vector_database(prompts)
|
26 |
+
|
27 |
+
@app.get("/")
def read_root():
    """Health-check endpoint: return a fixed status message."""
    status = {"message": "Prompt Search Engine is running!"}
    return status
|
30 |
+
|
31 |
+
@app.post("/search/")
async def search_prompts(query: Query, k: int = 3):
    """Return the ``k`` stored prompts closest to ``query.prompt``.

    Args:
        query: Request body carrying the prompt text to search for.
        k: Number of results requested (read from the query string).

    Returns:
        A ``SearchResponse`` whose results pair each prompt with the distance
        reported by the search engine.

    Raises:
        HTTPException: 400 when ``k`` is not a positive integer.
    """
    if k < 1:
        raise HTTPException(status_code=400, detail="k must be a positive integer")

    # Never ask for more neighbours than there are stored prompts: FAISS pads
    # missing hits with index -1, which would silently map to the last prompt.
    top_k = min(k, len(search_engine.prompts_track))
    if top_k == 0:
        return SearchResponse(results=[])

    logger = logging.getLogger(__name__)
    logger.info("Prompt: %s", query.prompt)
    similar_prompts, distances = search_engine.most_similar(query.prompt, top_k=top_k)
    logger.debug("Similar Prompts %s", similar_prompts)
    logger.debug("Distances %s", distances)

    # Format the response; float() converts numpy scalars for serialization.
    response = [
        SimilarPrompt(prompt=prompt, distance=float(distance))
        for prompt, distance in zip(similar_prompts, distances)
    ]

    return SearchResponse(results=response)
|
45 |
+
|
46 |
+
@app.post("/all_vectors_similarities/")
async def all_vectors(query: Query):
    """Return the cosine similarity between the query and every stored prompt.

    Args:
        query: Request body carrying the prompt text to compare.

    Returns:
        A ``SearchResponse`` with one entry per stored prompt. Note that the
        ``distance`` field holds a cosine *similarity* here, not a distance.
    """
    index = search_engine.index
    if index.ntotal == 0:
        return SearchResponse(results=[])

    # Cosine similarity works on vectors, so embed the query text and pull the
    # stored embeddings back out of the flat FAISS index.
    query_vector = search_engine.model.encode([query.prompt])[0]
    corpus_vectors = index.reconstruct_n(0, index.ntotal)

    # Called through the class so it works whether or not the helper is bound.
    all_similarities = PromptSearchEngine.cosine_similarity(query_vector, corpus_vectors)

    # Keys are positions in the index; map them back to the prompt text.
    response = [
        SimilarPrompt(prompt=search_engine.prompts_track[position], distance=float(similarity))
        for position, similarity in all_similarities.items()
    ]
    return SearchResponse(results=response)
|
55 |
+
|
56 |
+
if __name__ == "__main__":
    # Serve the API on the address that "localhost" resolves to, port 8000.
    uvicorn.run(app, host=socket.gethostbyname("localhost"), port=8000)
|
models/Query.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
|
5 |
+
|
6 |
+
class Query(BaseModel):
    """Request body for a single prompt lookup."""

    # Caller-supplied identifier for the request.
    id: str
    # Text to search for.
    prompt: str
|
10 |
+
|
11 |
+
class Query_Multiple(BaseModel):
    """Request body bundling several ``Query`` objects."""

    # The batch of individual queries.
    prompt: List[Query]
|
13 |
+
|
14 |
+
|
15 |
+
class SimilarPrompt(BaseModel):
    """One search hit: a prompt together with its score."""

    # Text of the matched prompt.
    prompt: str
    # Score reported by the search engine for this prompt.
    distance: float
|
18 |
+
|
19 |
+
class SearchResponse(BaseModel):
    """Response envelope returned by the search endpoints."""

    # Hits in the order produced by the endpoint.
    results: List[SimilarPrompt]
|
models/__pycache__/Query.cpython-312.pyc
ADDED
Binary file (1.11 kB). View file
|
|
models/__pycache__/data_reader.cpython-312.pyc
ADDED
Binary file (2.35 kB). View file
|
|
models/__pycache__/prompt_search_engine.cpython-312.pyc
ADDED
Binary file (3.53 kB). View file
|
|
models/__pycache__/vectorizer.cpython-312.pyc
ADDED
Binary file (1.82 kB). View file
|
|
models/data_reader.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datasets import load_dataset
|
2 |
+
import json
|
3 |
+
|
4 |
+
|
5 |
+
# URL template for the dataset shards; ``{i:06d}`` is replaced by the
# zero-padded shard number.
base_url = (
    "https://huggingface.co/datasets/jackyhate/text-to-image-2M"
    "/resolve/main/data_512_2M/data_{i:06d}.tar"
)
# Number of webdataset tar files
num_shards = 46
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
def download_data(base_url, num_shards):
    """Open the first ``num_shards`` shards as one streaming webdataset.

    Args:
        base_url: URL template containing an ``{i}`` placeholder for the shard number.
        num_shards: How many shards, counted from 0, to include.

    Returns:
        The streaming "train" split returned by ``load_dataset``.
    """
    shard_urls = []
    for shard_number in range(num_shards):
        shard_urls.append(base_url.format(i=shard_number))
    return load_dataset(
        "webdataset",
        data_files={"train": shard_urls},
        split="train",
        streaming=True,
    )
|
16 |
+
|
17 |
+
def extract_prompts(dataset, json_file_path):
    """Write the prompt of every dataset row to a JSON Lines file.

    Args:
        dataset: Iterable of rows, each exposing ``row['json']['prompt']``.
        json_file_path: Destination file; it is created or overwritten.
    """
    # Write to the path that was passed in (not a module-level global) and
    # stream row by row, so no copy of the whole dataset is held in memory.
    with open(json_file_path, 'w', encoding='utf-8') as f:
        for row in dataset:
            f.write(json.dumps(row['json']['prompt']) + '\n')
|
24 |
+
|
25 |
+
|
26 |
+
def read_data(jsonl_file_path):
    """Print each record of a JSON Lines file to standard output.

    Args:
        jsonl_file_path: Path of the file to read; one JSON document per line.
    """
    with open(jsonl_file_path, 'r') as handle:
        # map() is lazy, so each line is decoded right before it is printed.
        for record in map(json.loads, handle):
            print(record)
|
33 |
+
|
34 |
+
def load_prompts_from_jsonl(file_path):
    """Load every record of a JSON Lines file into a list.

    Args:
        file_path: Path of the file to read; one JSON document per line.

    Returns:
        The decoded records in file order. Each record is appended as decoded;
        no field is extracted from it.
    """
    prompts = []
    # Explicit UTF-8 so decoding does not depend on the platform default.
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) are not valid JSON.
                continue
            prompts.append(json.loads(line))
    return prompts
|
41 |
+
|
42 |
+
|
43 |
+
if __name__ == "__main__":
    from pathlib import Path

    # Store the prompts next to this module rather than at a machine-specific path.
    jsonl_file_path = str(Path(__file__).resolve().parent / "prompts_data.jsonl")
    num_shards = 1
    # Argument order must match download_data(base_url, num_shards).
    dataset = download_data(base_url, num_shards)
    extract_prompts(dataset, jsonl_file_path)
    read_data(jsonl_file_path)
|
models/prompt_search_engine.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Sequence, List, Tuple
|
2 |
+
from models.vectorizer import Vectorizer
|
3 |
+
import numpy as np
|
4 |
+
from sentence_transformers import SentenceTransformer
|
5 |
+
import faiss
|
6 |
+
|
7 |
+
class PromptSearchEngine:
    """Prompt search over sentence embeddings stored in a flat FAISS L2 index."""

    def __init__(self, model_name='bert-base-nli-mean-tokens'):
        """Load the embedding model and create an empty index for its vectors.

        Args:
            model_name: Model name handed to ``SentenceTransformer``.
        """
        self.model = SentenceTransformer(model_name)
        # The index must have the same dimensionality as the model's embeddings.
        self.embedding_dimension = self.model.get_sentence_embedding_dimension()
        # Euclidean-distance index - brute force, suitable for small datasets.
        self.index = faiss.IndexFlatL2(self.embedding_dimension)
        # Original prompts, position-aligned with the vectors in the index.
        self.prompts_track = []

    def add_prompts_to_vector_database(self, prompts):
        """Embed ``prompts`` and append them to the index.

        Args:
            prompts: Prompt strings to store. An empty input is a no-op.
        """
        prompts = list(prompts)
        if not prompts:
            # Nothing to encode; an empty batch has no 2-D shape to add.
            return
        embeddings = np.asarray(self.model.encode(prompts), dtype='float32')
        self.index.add(embeddings)
        self.prompts_track.extend(prompts)

    def most_similar(self, query, top_k=5):
        """Find the stored prompts closest to ``query``.

        Args:
            query: Text to search for.
            top_k: Maximum number of results; capped at the number of stored prompts.

        Returns:
            A tuple ``(similar_prompts, distances)``: the matched prompts and
            the array of distances reported by the index, in the same order.
        """
        # FAISS pads missing neighbours with index -1, which would wrap around
        # to the last prompt, so never request more than the index holds.
        k = min(int(top_k), self.index.ntotal)
        if k <= 0:
            return [], np.empty(0, dtype='float32')

        # Encode the query
        query_embedding = np.asarray(self.model.encode([query]), dtype='float32')

        # Optimized search, but the index type would have to be changed.
        distances, indices = self.index.search(query_embedding, k)

        # Keep only real hits, then look up their original prompts.
        hits = indices[0] >= 0
        similar_prompts = [self.prompts_track[idx] for idx in indices[0][hits]]

        return similar_prompts, distances[0][hits]

    @staticmethod
    def cosine_similarity(query_vector: np.ndarray, corpus_vectors: np.ndarray) -> dict:
        """Compute the cosine similarity between a query vector and corpus vectors.

        Declared static because it uses no instance state; it can be called on
        the class or on an instance.

        Args:
            query_vector: The query vector to compare against the corpus vectors.
            corpus_vectors: Iterable of corpus vectors.

        Returns:
            A dict mapping each corpus position to its cosine similarity.

        Raises:
            ValueError: If the query vector or any corpus vector has zero norm.
        """
        # The query norm is loop-invariant, so compute it once.
        query_norm = np.linalg.norm(query_vector)
        if query_norm == 0:
            raise ValueError("The query vector has zero norm.")

        similarities = {}
        for index, vector in enumerate(corpus_vectors):
            vector_norm = np.linalg.norm(vector)
            if vector_norm == 0:
                raise ValueError("One of the corpus vectors has zero norm.")
            similarities[index] = np.dot(vector, query_vector) / (vector_norm * query_norm)
        return similarities
|
47 |
+
|
48 |
+
|
models/prompts_data.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/vectorizer.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from sentence_transformers import SentenceTransformer
|
3 |
+
import numpy as np
|
4 |
+
from typing import Sequence
|
5 |
+
import faiss
|
6 |
+
|
7 |
+
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
class Vectorizer:
    """Embed prompts with a model and store them in an inner-product FAISS index."""

    def __init__(self, model) -> None:
        """Initialize the vectorizer with a pre-trained embedding model.

        Args:
            model: The pre-trained embedding model to use for transforming
                prompts; it must provide ``encode``.
        """
        self.model = model
        # Kept for backward compatibility; it is no longer used as the index
        # dimensionality, which must equal the embedding size instead.
        self.index_size = 50000
        # Size the index from the model when it reports its embedding size;
        # otherwise the index is created lazily on the first add.
        dimension_getter = getattr(model, "get_sentence_embedding_dimension", None)
        dimension = dimension_getter() if callable(dimension_getter) else None
        self.index = faiss.IndexFlatIP(dimension) if dimension else None
        self.cached_index_idx_to_retrieval_db_idx = []

    def transform_and_add_to_index(self, prompts: Sequence[str]) -> np.ndarray:
        """Transform texts into numerical vectors and add them to the index.

        Args:
            prompts: The sequence of raw corpus prompts.

        Returns:
            The vectorized prompts as a float32 array.

        Raises:
            ValueError: If the embedding size does not match the existing index.
        """
        embeddings = np.asarray(self.model.encode(prompts), dtype="float32")
        if embeddings.size == 0:
            # Nothing was encoded, so there is nothing to add.
            return embeddings
        # A single encoded text may come back 1-D; the index needs (n, dim).
        embeddings = np.atleast_2d(embeddings)

        embedding_dimension = embeddings.shape[1]
        print('Embedding dimension:', embedding_dimension)

        if self.index is None:
            self.index = faiss.IndexFlatIP(embedding_dimension)
        elif self.index.d != embedding_dimension:
            raise ValueError(
                f"Embedding dimension {embedding_dimension} does not match "
                f"index dimension {self.index.d}."
            )

        self.index.add(embeddings)
        return embeddings
|
31 |
+
|
32 |
+
|
33 |
+
|
requirements.txt
ADDED
Binary file (5.06 kB). View file
|
|