import time
import re

import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from tokenizers import Tokenizer, AddedToken
import streamlit as st
from st_click_detector import click_detector

DEVICE = "cpu"
MODEL_OPTIONS = ["msmarco-distilbert-base-tas-b", "all-mpnet-base-v2"]
DESCRIPTION = """
# Semantic search

**Enter your query and hit enter**

Built with 🤗 Hugging Face's [transformers](https://huggingface.co/transformers/) library, [SentenceBert](https://www.sbert.net/) models, [Streamlit](https://streamlit.io/) and 44k movie descriptions from the Kaggle [Movies Dataset](https://www.kaggle.com/rounakbanik/the-movies-dataset)
"""
def load():
    """Load both sentence-embedding models and tokenizers, the precomputed
    corpus embeddings (one .npy file per model) and the movie metadata."""
    models, tokenizers, embeddings = [], [], []
    for model_option in MODEL_OPTIONS:
        tokenizers.append(
            AutoTokenizer.from_pretrained(f"sentence-transformers/{model_option}")
        )
        models.append(
            AutoModel.from_pretrained(f"sentence-transformers/{model_option}").to(
                DEVICE
            )
        )
    embeddings.append(np.load("embeddings.npy"))
    embeddings.append(np.load("embeddings2.npy"))
    df = pd.read_csv("movies.csv")
    return tokenizers, models, embeddings, df


tokenizers, models, embeddings, df = load()
def pooling(model_output):
    # CLS pooling: use the hidden state of the first token as the sentence
    # embedding (the strategy used by msmarco-distilbert-base-tas-b).
    return model_output.last_hidden_state[:, 0]


def compute_embeddings(texts):
    encoded_input = tokenizers[0](
        texts, padding=True, truncation=True, return_tensors="pt"
    ).to(DEVICE)
    with torch.no_grad():
        model_output = models[0](**encoded_input, return_dict=True)
    embeddings = pooling(model_output)
    return embeddings.cpu().numpy()
def pooling2(model_output, attention_mask):
    # Mean pooling: average the token embeddings, ignoring padding tokens
    # (the strategy used by all-mpnet-base-v2).
    token_embeddings = model_output[0]
    input_mask_expanded = (
        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    )
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )


def compute_embeddings2(list_of_strings):
    encoded_input = tokenizers[1](
        list_of_strings, padding=True, truncation=True, return_tensors="pt"
    ).to(DEVICE)
    with torch.no_grad():
        model_output = models[1](**encoded_input)
    sentence_embeddings = pooling2(model_output, encoded_input["attention_mask"])
    # L2-normalize so the dot product used for ranking is the cosine similarity.
    return F.normalize(sentence_embeddings, p=2, dim=1).cpu().numpy()
def semantic_search(query, model_id):
    start = time.time()
    if len(query.strip()) == 0:
        return ""
    if "[Similar:" not in query:
        # Free-text query: embed it with the selected model.
        if model_id == 0:
            query_embedding = compute_embeddings([query])
        else:
            query_embedding = compute_embeddings2([query])
    else:
        # "Similar movies" click: reuse the precomputed embedding of that movie.
        match = re.match(r"\[Similar:(\d{1,5}).*", query)
        if match:
            idx = int(match.groups()[0])
            query_embedding = embeddings[model_id][idx : idx + 1, :]
            if query_embedding.shape[0] == 0:
                return ""
        else:
            return ""
    # Rank every movie by dot product with the query and keep the top 10.
    indices = np.argsort(embeddings[model_id] @ np.transpose(query_embedding)[:, 0])[
        -1:-11:-1
    ]
    if len(indices) == 0:
        return ""
    result = "<ol>"
    for i in indices:
        result += f"<li style='padding-top: 10px'><b>{df.iloc[i].title}</b> ({df.iloc[i].release_date}). {df.iloc[i].overview} "
        result += f"<a id='{i}' href='#'>Similar movies</a></li>"
    delay = "%.3f" % (time.time() - start)
    return f"<p><i>Computation time: {delay} seconds</i></p>{result}</ol>"
st.sidebar.markdown(DESCRIPTION)
model_choice = st.sidebar.selectbox("Similarity model", options=MODEL_OPTIONS)
model_id = 0 if model_choice == MODEL_OPTIONS[0] else 1

if "query" in st.session_state:
    query = st.text_input("", value=st.session_state["query"])
else:
    query = st.text_input("", value="time travel")

# Render the results as clickable HTML; click_detector returns the id of the
# clicked "Similar movies" anchor, which is the movie's row index.
clicked = click_detector(semantic_search(query, model_id))
if clicked != "":
    change_query = False
    if "last_clicked" not in st.session_state:
        st.session_state["last_clicked"] = clicked
        change_query = True
    elif clicked != st.session_state["last_clicked"]:
        st.session_state["last_clicked"] = clicked
        change_query = True
    if change_query:
        # Turn the click into a "[Similar:<idx>] <title>" query and rerun the app.
        st.session_state["query"] = f"[Similar:{clicked}] {df.iloc[int(clicked)].title}"
        st.experimental_rerun()
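
# ---------------------------------------------------------------------------
# Note (assumption, not part of the original app): embeddings.npy and
# embeddings2.npy are expected to hold one precomputed embedding per row of
# movies.csv, in the same order as the dataframe. A minimal offline sketch
# that would produce such files with the functions defined above (batch size
# of 64 is an arbitrary choice):
#
#     df = pd.read_csv("movies.csv")
#     overviews = df["overview"].fillna("").astype(str).tolist()
#     batches = [overviews[i : i + 64] for i in range(0, len(overviews), 64)]
#     np.save("embeddings.npy", np.vstack([compute_embeddings(b) for b in batches]))
#     np.save("embeddings2.npy", np.vstack([compute_embeddings2(b) for b in batches]))
# ---------------------------------------------------------------------------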