|
import faiss |
|
import pickle |
|
import datasets |
|
import numpy as np |
|
import requests |
|
import streamlit as st |
|
from vector_engine.utils import vector_search |
|
from transformers import AutoModel, AutoTokenizer |
|
|
|
from datasets import load_dataset |
|
|
|
|
|
@st.cache_data
def read_data(dataset_repo='dhmeltzer/asks_validation_embedded'):
    """Load a Hugging Face dataset and return its 'validation_asks' split.

    Args:
        dataset_repo: Identifier passed to ``load_dataset``.

    Returns:
        The ``'validation_asks'`` entry of the loaded dataset. The result is
        cached by ``st.cache_data``.
    """
    dataset_splits = load_dataset(dataset_repo)
    return dataset_splits['validation_asks']
|
|
|
@st.cache_resource
def load_faiss_index(path_to_faiss="./faiss_index_small.pickle"):
    """Load and deserialize the Faiss index.

    The index is cached with ``st.cache_resource`` rather than
    ``st.cache_data``: it is a native object meant to be shared between
    reruns, whereas ``st.cache_data`` would pickle the return value and hand
    back a fresh copy on every call.

    Args:
        path_to_faiss: Path to a pickle file holding the serialized index
            (the payload accepted by ``faiss.deserialize_index``).

    Returns:
        The deserialized Faiss index.

    Raises:
        OSError: If the file cannot be opened.
    """
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file, so this path must only ever point at a trusted artifact.
    with open(path_to_faiss, "rb") as handle:
        serialized_index = pickle.load(handle)
    return faiss.deserialize_index(serialized_index)
|
|
|
def main():
    """Render the Streamlit search page and show the nearest matches.

    Loads the dataset and the Faiss index, embeds the text from the search
    box through the Hugging Face feature-extraction endpoint, looks up the
    nearest neighbours in the index, and writes each hit's title, top-answer
    score and top answer to the page.
    """
    data = read_data()
    faiss_index = load_faiss_index()

    model_id = "sentence-transformers/nli-distilbert-base"
    api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_id}"
    headers = {"Authorization": f"Bearer {st.secrets['HF_token']}"}

    def query(texts):
        """POST *texts* to the embedding endpoint and return the parsed JSON."""
        response = requests.post(
            api_url,
            headers=headers,
            json={"inputs": texts, "options": {"wait_for_model": True}},
        )
        # Fail here on HTTP errors instead of handing an error payload to Faiss.
        response.raise_for_status()
        return response.json()

    st.title("Vector-based searches with Sentence Transformers and Faiss")

    user_input = st.text_area("Search box", "What is spacetime made out of?")

    st.sidebar.markdown("**Filters**")
    num_results = st.sidebar.slider("Number of search results", 1, 50, 1)

    # Nothing to search for: skip the remote embedding call entirely.
    if not user_input:
        return

    try:
        vector = query([user_input])
    except (requests.RequestException, ValueError) as err:
        # ValueError covers a non-JSON body on older versions of requests.
        st.error(f"Could not compute the query embedding: {err}")
        return

    _, neighbor_ids = faiss_index.search(
        np.asarray(vector, dtype="float32"), k=num_results
    )

    for id_ in neighbor_ids.flatten().tolist():
        # Faiss pads the result with -1 when fewer than k neighbours exist;
        # indexing the dataset with -1 would silently pick the last row.
        if id_ < 0:
            continue

        row = data[id_]
        answers = row['answers']['text']
        if not answers:
            # No answer to display for this hit.
            continue

        # Only the first answer is shown, so only it needs its _URL_<k>_
        # placeholders expanded.
        # NOTE(review): assumes placeholder k maps to answers_urls['url'][k],
        # as the original replacement loop did.
        top_answer = answers[0]
        for k, url in enumerate(row['answers_urls']['url']):
            top_answer = top_answer.replace(f'_URL_{k}_', url)

        # Same text as the original triple-quoted template, spelled with
        # explicit newlines so it does not depend on source indentation.
        st.write(
            f"**Title**: {row['title']}\n\n\n"
            f"**Score**: {row['answers']['score'][0]}\n\n\n"
            f"**Top Answer**: {top_answer}\n"
        )
        st.write("-" * 20)
|
|
|
|
|
|
|
# Launch the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|