"""Streamlit app: semantic search over movie descriptions with SentenceBert models."""

import html
import re
import time

import numpy as np
import pandas as pd
import streamlit as st
import torch
import torch.nn.functional as F
from st_click_detector import click_detector
from tokenizers import AddedToken, Tokenizer
from transformers import AutoModel, AutoTokenizer

DEVICE = "cpu"
MODEL_OPTIONS = ["msmarco-distilbert-base-tas-b", "all-mpnet-base-v2"]
# Number of results listed for each query.
TOP_K = 10

DESCRIPTION = """
# Semantic search

**Enter your query and hit enter**

Built with 🤗 Hugging Face's [transformers](https://huggingface.co/transformers/) library, [SentenceBert](https://www.sbert.net/) models, [Streamlit](https://streamlit.io/) and 44k movie descriptions from the Kaggle [Movies Dataset](https://www.kaggle.com/rounakbanik/the-movies-dataset)
"""


@st.cache(
    show_spinner=False,
    hash_funcs={
        AutoModel: lambda _: None,
        AutoTokenizer: lambda _: None,
        dict: lambda _: None,
    },
)
def load():
    """Load the tokenizers, models, precomputed corpus embeddings and movie table.

    Returns:
        ``(tokenizers, models, embeddings, df)``. ``tokenizers[k]`` and
        ``models[k]`` belong to ``MODEL_OPTIONS[k]``; ``embeddings[k]`` is the
        corpus matrix searched when model ``k`` is selected. Row ``i`` of each
        matrix is looked up as ``df.iloc[i]`` by ``semantic_search``.
    """
    tokenizers, models = [], []
    for model_option in MODEL_OPTIONS:
        checkpoint = f"sentence-transformers/{model_option}"
        tokenizers.append(AutoTokenizer.from_pretrained(checkpoint))
        models.append(AutoModel.from_pretrained(checkpoint).to(DEVICE))
    # One precomputed embedding file per entry of MODEL_OPTIONS, same order.
    embeddings = [np.load("embeddings.npy"), np.load("embeddings2.npy")]
    df = pd.read_csv("movies.csv")
    return tokenizers, models, embeddings, df


tokenizers, models, embeddings, df = load()


def pooling(model_output):
    """CLS pooling: take the first token's final hidden state for each sequence."""
    return model_output.last_hidden_state[:, 0]


def compute_embeddings(texts):
    """Embed ``texts`` with model 0 (CLS pooling, no normalization); returns a NumPy array."""
    encoded_input = tokenizers[0](
        texts, padding=True, truncation=True, return_tensors="pt"
    ).to(DEVICE)
    with torch.no_grad():
        model_output = models[0](**encoded_input, return_dict=True)
    # Named so it does not shadow the module-level corpus `embeddings` list.
    cls_embeddings = pooling(model_output)
    return cls_embeddings.cpu().numpy()


def pooling2(model_output, attention_mask):
    """Mean pooling over tokens, ignoring padding positions via ``attention_mask``."""
    token_embeddings = model_output[0]
    input_mask_expanded = (
        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    )
    # clamp avoids division by zero for an all-padding row.
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )


def compute_embeddings2(list_of_strings):
    """Embed strings with model 1 (mean pooling, L2-normalized); returns a NumPy array."""
    encoded_input = tokenizers[1](
        list_of_strings, padding=True, truncation=True, return_tensors="pt"
    ).to(DEVICE)
    with torch.no_grad():
        model_output = models[1](**encoded_input)
    sentence_embeddings = pooling2(model_output, encoded_input["attention_mask"])
    return F.normalize(sentence_embeddings, p=2, dim=1).cpu().numpy()


@st.cache(
    show_spinner=False,
    hash_funcs={Tokenizer: lambda _: None, AddedToken: lambda _: None},
)
def semantic_search(query, model_id):
    """Return an HTML fragment listing the ``TOP_K`` movies most similar to ``query``.

    ``query`` is either free text, embedded with the model selected by
    ``model_id`` (0 -> ``compute_embeddings``, otherwise ``compute_embeddings2``),
    or starts with ``"[Similar:<row>]"`` (as produced by the click handler), in
    which case the stored embedding of corpus row ``<row>`` is the query.

    Each result carries an ``<a>`` whose ``id`` is the movie's row index;
    ``click_detector`` returns that id when the link is clicked.

    Returns "" for a blank query, an out-of-range row, or an empty corpus.
    """
    start = time.time()
    if not query.strip():
        return ""

    match = re.match(r"\[Similar:(\d{1,5})", query)
    if match:
        idx = int(match.group(1))
        query_embedding = embeddings[model_id][idx : idx + 1, :]
        if query_embedding.shape[0] == 0:  # row index past the end of the corpus
            return ""
    elif model_id == 0:
        query_embedding = compute_embeddings([query])
    else:
        query_embedding = compute_embeddings2([query])

    # Dot-product score of every corpus row against the single query vector.
    scores = embeddings[model_id] @ query_embedding[0]
    indices = np.argsort(scores)[::-1][:TOP_K]  # best first
    if len(indices) == 0:
        return ""

    items = []
    for i in indices:
        movie = df.iloc[i]
        # Escape dataset text so stray '<' / '&' cannot break the markup.
        items.append(
            f"<li>{html.escape(str(movie['title']))} "
            f"({html.escape(str(movie['release_date']))}). "
            f"{html.escape(str(movie['overview']))} "
            f"<a href='#' id='{i}'>Similar movies</a></li>"
        )
    delay = f"{time.time() - start:.3f}"
    return (
        f"<p><i>Computation time: {delay} seconds</i></p>"
        f"<ol>{''.join(items)}</ol>"
    )


st.sidebar.markdown(DESCRIPTION)

model_choice = st.sidebar.selectbox("Similarity model", options=MODEL_OPTIONS)
model_id = MODEL_OPTIONS.index(model_choice)

query = st.text_input(
    "", value=st.session_state.get("query", "artificial intelligence")
)

clicked = click_detector(semantic_search(query, model_id))

if clicked != "":
    st.markdown(clicked)
    # Only react to a new click; the component keeps returning the last id on reruns.
    if st.session_state.get("last_clicked") != clicked:
        st.session_state["last_clicked"] = clicked
        st.session_state["query"] = f"[Similar:{clicked}] {df.iloc[int(clicked)].title}"
        st.experimental_rerun()