Spaces:
Build error
Build error
import html
import re
import time

import numpy as np
import pandas as pd
import streamlit as st
import torch
import torch.nn.functional as F
from st_click_detector import click_detector
from tokenizers import Tokenizer, AddedToken
from transformers import AutoTokenizer, AutoModel
# This little script is my test of the new experimental primitives, which promise to put caching in Streamlit within striking distance of simulating cognitive episodic memory (personalized feelings about a moment in space-time) and semantic memory (factual memories we are ready to share and communicate, like your email address or physical address).
# What impresses me about these two new Streamlit persistence primitives is that the one called the singleton can share memory across sessions (that is, across all users).
#@st.experimental_singleton | |
#def get_sessionmaker(search_param): | |
# url = "https://en.wikipedia.org/wiki/" | |
# return url | |
#search_param = "Star_Trek:_Discovery" | |
#sm= get_sessionmaker(search_param) | |
# What is super cool about the second primitive, the memo, is that it makes unwieldy data very wieldy. Like The Lord of the Rings in reverse: "You cannot wield it! None of us can." -> "You can wield it, now everyone can."
#@st.experimental_memo | |
#def factorial(n): | |
# if n < 1: | |
# return 1 | |
# return n * factorial(n - 1) | |
#em10 = factorial(10) | |
#em09 = factorial(9) # Returns instantly! | |
# Radio button persistence: the plan is to hydrate the radio from the URL when a value is selected, and to change the URL along with the text box and the search.
# Topic options shared by the radio button and the sidebar select box.
options = ["ai", "nlp", "iot", "vr", "genomics", "graph", "cognitive"]


def update_params():
    """Mirror the radio selection into the URL's ``query`` parameter.

    Registered as the radio's ``on_change`` callback, so it must be defined
    before the radio is created.
    """
    st.experimental_set_query_params(query=st.session_state["topic"])


# Hydrate the radio from the URL. A missing ``query`` parameter (KeyError),
# an empty value list (IndexError) or an unknown topic (ValueError) all fall
# back to the first option.
query_params = st.experimental_get_query_params()
ix = 0
try:
    ix = options.index(query_params["query"][0])
except (KeyError, IndexError, ValueError):
    pass

# The radio owns the "topic" key; "query" is reserved for the search text so
# the two widgets no longer collide on the same key.
selected_option = st.radio(
    "Param", options, index=ix, key="topic", on_change=update_params
)
# Write under the same parameter name ("query") that every reader uses.
st.experimental_set_query_params(query=selected_option)

# Search text box. On the first visit, seed the stored query with "AI".
if "query" not in st.session_state:
    st.session_state["query"] = "AI"
# The text box is deliberately left un-keyed: the click handler further down
# assigns st.session_state["query"] after this widget exists, which Streamlit
# rejects for a key that is bound to a widget.
query = st.text_input("", value=st.session_state["query"])

# Sidebar select box, initialised from the URL's ``query`` parameter.
query_params = st.experimental_get_query_params()
try:
    query_option = query_params["query"][0]
    option_index = options.index(query_option)
except (KeyError, IndexError, ValueError):
    # Missing or unknown value (e.g. visiting http://host:port directly):
    # fall back to a default that actually exists in ``options``.
    query_option = "genomics"
    option_index = options.index(query_option)
    st.experimental_set_query_params(query=query_option)
    query_params = st.experimental_get_query_params()
option_selected = st.sidebar.selectbox("Pick option", options, index=option_index)
# Torch device that the models and tokenized inputs are moved to.
DEVICE = "cpu"
# sentence-transformers checkpoints; the position in this list is the model id.
MODEL_OPTIONS = ["msmarco-distilbert-base-tas-b", "all-mpnet-base-v2"]
# Markdown blurb rendered in the sidebar.
DESCRIPTION = """
# Semantic search
**Enter your query and hit enter**
Built with 🤗 Hugging Face's [transformers](https://huggingface.co/transformers/) library, [SentenceBert](https://www.sbert.net/) models, [Streamlit](https://streamlit.io/) and 44k movie descriptions from the Kaggle [Movies Dataset](https://www.kaggle.com/rounakbanik/the-movies-dataset)
"""
# Session state - search params.
# Demo of the session-state API: seed an entry, display it together with the
# whole session state, then remove it again.
if "key" not in st.session_state:
    st.session_state["key"] = "value"
st.write(st.session_state.key)
st.write(st.session_state)
# Remove only the demo entry. Deleting every key here would also drop
# "query" and "last_clicked", which the click handling at the bottom of the
# script expects to survive from one run to the next.
del st.session_state["key"]
#st.text_input("Your name", key="name") | |
#st.session_state.name | |
def load():
    """Load the tokenizers, models, precomputed embeddings and movie table.

    One tokenizer/model pair is loaded per entry of ``MODEL_OPTIONS`` from the
    ``sentence-transformers/`` namespace, and each model is moved to
    ``DEVICE``. The embedding matrices are read from ``embeddings.npy`` and
    ``embeddings2.npy`` and the movie table from ``movies.csv``.

    Returns:
        Tuple ``(tokenizers, models, embeddings, df)``; the three lists are
        index-aligned with ``MODEL_OPTIONS``.
    """
    tokenizers = []
    models = []
    for name in MODEL_OPTIONS:
        checkpoint = f"sentence-transformers/{name}"
        tokenizers.append(AutoTokenizer.from_pretrained(checkpoint))
        model = AutoModel.from_pretrained(checkpoint)
        models.append(model.to(DEVICE))

    embeddings = [np.load(path) for path in ("embeddings.npy", "embeddings2.npy")]
    df = pd.read_csv("movies.csv")
    return tokenizers, models, embeddings, df


tokenizers, models, embeddings, df = load()
def pooling(model_output):
    """Return the hidden state at position 0 of every sequence in the batch.

    Args:
        model_output: Object exposing a ``last_hidden_state`` tensor.

    Returns:
        ``last_hidden_state[:, 0]``.
    """
    hidden_states = model_output.last_hidden_state
    first_token = hidden_states[:, 0]
    return first_token
def compute_embeddings(texts):
    """Embed *texts* with the first tokenizer/model pair.

    The texts are tokenized with padding and truncation, run through
    ``models[0]`` without gradient tracking, and reduced with ``pooling``.

    Args:
        texts: Text input passed straight to the tokenizer (the caller in
            this file passes a list of strings).

    Returns:
        The pooled embeddings as a NumPy array.
    """
    tokenizer, model = tokenizers[0], models[0]
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    batch = batch.to(DEVICE)
    with torch.no_grad():
        output = model(**batch, return_dict=True)
        pooled = pooling(output)
    return pooled.cpu().numpy()
def pooling2(model_output, attention_mask):
    """Average the token embeddings of each sequence, ignoring masked tokens.

    Args:
        model_output: Indexable model output whose element 0 holds the
            per-token embeddings.
        attention_mask: Mask tensor; it is broadcast over the embedding
            dimension and used as a per-token weight.

    Returns:
        Tensor of mask-weighted means taken over dimension 1.
    """
    token_embeddings = model_output[0]
    weights = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    weighted_sum = (token_embeddings * weights).sum(dim=1)
    # Clamp so that a fully masked sequence does not divide by zero.
    weight_total = weights.sum(dim=1).clamp(min=1e-9)
    return weighted_sum / weight_total
def compute_embeddings2(list_of_strings):
    """Embed *list_of_strings* with the second tokenizer/model pair.

    The strings are tokenized with padding and truncation, run through
    ``models[1]`` without gradient tracking, reduced with ``pooling2`` and
    L2-normalized along dimension 1.

    Args:
        list_of_strings: Text input passed straight to the tokenizer.

    Returns:
        The normalized sentence embeddings as a NumPy array.
    """
    tokenizer, model = tokenizers[1], models[1]
    batch = tokenizer(
        list_of_strings, padding=True, truncation=True, return_tensors="pt"
    )
    batch = batch.to(DEVICE)
    with torch.no_grad():
        output = model(**batch)
        pooled = pooling2(output, batch["attention_mask"])
    normalized = F.normalize(pooled, p=2, dim=1)
    return normalized.cpu().numpy()
def semantic_search(query, model_id, top_k=10):
    """Rank the movies against *query* and render the best matches as HTML.

    Args:
        query: Either free text, or a string containing ``"[Similar:"``. In
            the second case it must start with ``[Similar:<digits>``; the
            digits select a row of the precomputed embeddings to use as the
            query vector instead of embedding text.
        model_id: Index into the module-level ``embeddings`` list. Free text
            is embedded with ``compute_embeddings`` when this is 0 and with
            ``compute_embeddings2`` otherwise.
        top_k: Maximum number of results to render (default 10).

    Returns:
        An HTML fragment with a timing line followed by an ordered list. Each
        item ends with a "Similar movies" anchor whose ``id`` is the movie's
        row index. Returns ``""`` when the query is blank, the ``[Similar:``
        tag is malformed or points past the last row, ``top_k`` is below 1,
        or nothing is ranked.
    """
    start = time.time()
    if not query.strip() or top_k < 1:
        return ""

    if "[Similar:" in query:
        match = re.match(r"\[Similar:(\d{1,5}).*", query)
        if match is None:
            return ""
        idx = int(match.group(1))
        # Slicing keeps the array two-dimensional and yields an empty array,
        # rather than raising, when idx is past the last row.
        query_embedding = embeddings[model_id][idx : idx + 1, :]
        if query_embedding.shape[0] == 0:
            return ""
    elif model_id == 0:
        query_embedding = compute_embeddings([query])
    else:
        query_embedding = compute_embeddings2([query])

    # Dot-product score of every stored embedding against the query vector,
    # then the top_k row indices in descending score order.
    scores = embeddings[model_id] @ np.transpose(query_embedding)[:, 0]
    indices = np.argsort(scores)[::-1][:top_k]
    if len(indices) == 0:
        return ""

    items = []
    for i in indices:
        row = df.iloc[i]
        # The fields come from the CSV, so escape them before embedding them
        # in markup; a stray "<" or "&" would otherwise corrupt the list.
        title = html.escape(str(row["title"]))
        released = html.escape(str(row["release_date"]))
        overview = html.escape(str(row["overview"]))
        items.append(
            f"<li style='padding-top: 10px'><b>{title}</b> ({released}). {overview} "
            f"<a id='{i}' href='#'>Similar movies</a></li>"
        )
    result = "<ol>" + "".join(items)
    delay = "%.3f" % (time.time() - start)
    return f"<p><i>Computation time: {delay} seconds</i></p>{result}</ol>"
# Sidebar: description blurb and similarity-model picker.
st.sidebar.markdown(DESCRIPTION)
model_choice = st.sidebar.selectbox("Similarity model", options=MODEL_OPTIONS)
model_id = 1 if model_choice != MODEL_OPTIONS[0] else 0

# Render the results; ``clicked`` is the id of the clicked "Similar movies"
# anchor and is compared against "" to detect that no click is pending.
results_html = semantic_search(query, model_id)
clicked = click_detector(results_html)

if clicked != "":
    st.markdown(clicked)
    # Only react when this click differs from the one already recorded.
    change_query = (
        "last_clicked" not in st.session_state
        or clicked != st.session_state["last_clicked"]
    )
    if change_query:
        st.session_state["last_clicked"] = clicked
        # Store a "[Similar:<row>] <title>" query and rerun the script.
        st.session_state["query"] = f"[Similar:{clicked}] {df.iloc[int(clicked)].title}"
        st.experimental_rerun()