import time import re import pandas as pd import numpy as np import torch import torch.nn.functional as F from transformers import AutoTokenizer, AutoModel from tokenizers import Tokenizer, AddedToken import streamlit as st from st_click_detector import click_detector DEVICE = "cpu" MODEL_OPTIONS = ["msmarco-distilbert-base-tas-b", "all-mpnet-base-v2"] DESCRIPTION = """ # Semantic search **Enter your query and hit enter** Built with 🤗 Hugging Face's [transformers](https://huggingface.co/transformers/) library, [SentenceBert](https://www.sbert.net/) models, [Streamlit](https://streamlit.io/) and 44k movie descriptions from the Kaggle [Movies Dataset](https://www.kaggle.com/rounakbanik/the-movies-dataset) """ try: query_params = st.experimental_get_query_params() query_option = query_params['query'][0] #throws an exception when visiting http://host:port option_selected = st.sidebar.selectbox('Pick option', options, index=options.index(query_option)) except: # catch exception and set query param to predefined value st.experimental_set_query_params(query="Genomics") # defaults to dog query_params = st.experimental_get_query_params() query_option = query_params['query'][0] if 'query' not in st.session_state: #st.session_state['query'] = 'value' query = st.text_input("", value="artificial intelligence", key="query") else: query = st.text_input("", value=st.session_state["query"], key="query") #st.session_state.query = query if 'query' not in st.session_state: st.session_state.query = 'value' st.write(st.session_state.query) # Session state if 'key' not in st.session_state: st.session_state['key'] = 'value' if 'key' not in st.session_state: st.session_state.key = 'value' st.write(st.session_state.key) st.write(st.session_state) #st.session_state for key in st.session_state.keys(): del st.session_state[key] #st.text_input("Your name", key="name") #st.session_state.name @st.cache( show_spinner=False, hash_funcs={ AutoModel: lambda _: None, AutoTokenizer: lambda _: None, dict: lambda _: None, }, ) def load(): models, tokenizers, embeddings = [], [], [] for model_option in MODEL_OPTIONS: tokenizers.append( AutoTokenizer.from_pretrained(f"sentence-transformers/{model_option}") ) models.append( AutoModel.from_pretrained(f"sentence-transformers/{model_option}").to( DEVICE ) ) embeddings.append(np.load("embeddings.npy")) embeddings.append(np.load("embeddings2.npy")) df = pd.read_csv("movies.csv") return tokenizers, models, embeddings, df tokenizers, models, embeddings, df = load() def pooling(model_output): return model_output.last_hidden_state[:, 0] def compute_embeddings(texts): encoded_input = tokenizers[0]( texts, padding=True, truncation=True, return_tensors="pt" ).to(DEVICE) with torch.no_grad(): model_output = models[0](**encoded_input, return_dict=True) embeddings = pooling(model_output) return embeddings.cpu().numpy() def pooling2(model_output, attention_mask): token_embeddings = model_output[0] input_mask_expanded = ( attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() ) return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp( input_mask_expanded.sum(1), min=1e-9 ) def compute_embeddings2(list_of_strings): encoded_input = tokenizers[1]( list_of_strings, padding=True, truncation=True, return_tensors="pt" ).to(DEVICE) with torch.no_grad(): model_output = models[1](**encoded_input) sentence_embeddings = pooling2(model_output, encoded_input["attention_mask"]) return F.normalize(sentence_embeddings, p=2, dim=1).cpu().numpy() @st.cache( show_spinner=False, hash_funcs={Tokenizer: lambda _: None, AddedToken: lambda _: None}, ) def semantic_search(query, model_id): start = time.time() if len(query.strip()) == 0: return "" if "[Similar:" not in query: if model_id == 0: query_embedding = compute_embeddings([query]) else: query_embedding = compute_embeddings2([query]) else: match = re.match(r"\[Similar:(\d{1,5}).*", query) if match: idx = int(match.groups()[0]) query_embedding = embeddings[model_id][idx : idx + 1, :] if query_embedding.shape[0] == 0: return "" else: return "" indices = np.argsort(embeddings[model_id] @ np.transpose(query_embedding)[:, 0])[ -1:-11:-1 ] if len(indices) == 0: return "" result = "
Computation time: {delay} seconds
{result}