# semantic / app.py
# (Hugging Face Spaces page residue, preserved as a comment so the file parses:
#  dhmeltzer — "Update app.py", commit 2f1f4ba, raw / history / blame, 2.83 kB)
import faiss
import pickle
import datasets
import numpy as np
import requests
import streamlit as st
from vector_engine.utils import vector_search
from transformers import AutoModel, AutoTokenizer
from datasets import load_dataset
#@st.cache
@st.cache_data
def read_data(dataset_repo='dhmeltzer/asks_validation_embedded'):
    """Download the embedded asks dataset from the Hugging Face Hub.

    Args:
        dataset_repo: Hub repo id of the dataset to load.

    Returns:
        The ``validation_asks`` split of the loaded dataset.
        Cached by Streamlit across reruns via ``st.cache_data``.
    """
    dataset = load_dataset(dataset_repo)
    return dataset['validation_asks']
@st.cache_data
def load_faiss_index(path_to_faiss="./faiss_index_small.pickle"):
    """Read a pickled, serialized Faiss index from disk and rebuild it.

    Args:
        path_to_faiss: Path to the pickle file holding the serialized index.

    Returns:
        The deserialized ``faiss`` index object.

    NOTE(review): ``pickle.load`` executes arbitrary code if the file is
    untrusted — assumed here the index file ships with the app; confirm.
    """
    with open(path_to_faiss, "rb") as index_file:
        serialized = pickle.load(index_file)
    return faiss.deserialize_index(serialized)
def main():
    """Streamlit entry point: semantic search over the asks dataset.

    Embeds the user's query via the Hugging Face Inference API
    (feature-extraction pipeline), runs a nearest-neighbour search on the
    local Faiss index, and renders the matching rows.
    """
    # Load data and models (both cached by Streamlit).
    data = read_data()
    #model = load_bert_model()
    #tok = load_tokenizer()
    faiss_index = load_faiss_index()

    model_id = "sentence-transformers/nli-distilbert-base"
    api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_id}"
    headers = {"Authorization": f"Bearer {st.secrets['HF_token']}"}

    def query(texts):
        """Embed ``texts`` with the hosted feature-extraction pipeline.

        Returns the decoded JSON response (a list of embedding vectors on
        success; the API may return an error dict — TODO confirm handling).
        """
        response = requests.post(
            api_url,
            headers=headers,
            json={"inputs": texts, "options": {"wait_for_model": True}},
        )
        return response.json()

    st.title("Vector-based searches with Sentence Transformers and Faiss")

    # User search
    user_input = st.text_area("Search box", "What is spacetime made out of?")

    # Filters
    st.sidebar.markdown("**Filters**")
    num_results = st.sidebar.slider("Number of search results", 1, 50, 1)

    # Fetch results
    if user_input:
        # FIX: embed only when there is a query; previously the remote API
        # was called unconditionally, before this guard.
        vector = query([user_input])

        # Get paper IDs: I holds the row ids of the k nearest neighbours.
        _, I = faiss_index.search(np.array(vector).astype("float32"), k=num_results)

        # Get individual results
        for id_ in I.flatten().tolist():
            row = data[id_]
            answers = row['answers']['text']
            answers_URLs = row['answers_urls']['url']
            # Substitute each _URL_k_ placeholder with its actual URL.
            for k in range(len(answers_URLs)):
                answers = [answer.replace(f'_URL_{k}_', answers_URLs[k]) for answer in answers]
            # FIX: was row[answers]['score'][0] — indexing the dataset row
            # with the `answers` *list* raises TypeError; the score lives
            # under the 'answers' key.
            st.write(
                f"""**Title**: {row['title']}
\n
**Score**: {row['answers']['score'][0]}
\n
**Top Answer**: {answers[0]}
"""
            )
            st.write("-" * 20)
# Run the Streamlit app when this file is executed as a script.
if __name__ == "__main__":
    main()
#@st.cache(allow_output_mutation=True)
#def load_bert_model(name="nli-distilbert-base"):
# """Instantiate a sentence-level DistilBERT model."""
# return AutoModel.from_pretrained(f'sentence-transformers/{name}')
#
#@st.cache(allow_output_mutation=True)
#def load_tokenizer(name="nli-distilbert-base"):
# return AutoTokenizer.from_pretrained(f'sentence-transformers/{name}')
#@st.cache(allow_output_mutation=True)