Spaces:
Runtime error
Runtime error
import sys | |
import os | |
import re | |
import torch | |
import base64 | |
import pandas as pd | |
import streamlit as st | |
from sklearn.metrics.pairwise import cosine_similarity | |
from sentence_transformers import SentenceTransformer | |
# Put the current working directory at the very front of the import path,
# then render the page heading.
sys.path[:0] = [os.getcwd()]
st.title("ArXiV Paper Recommender")
def set_background(main_bg, main_bg_ext="jpg"):
    """Use the image file at *main_bg* as the full-page app background.

    The image bytes are base64-encoded and inlined as a data URI inside a
    ``<style>`` block injected with ``st.markdown(..., unsafe_allow_html=True)``.

    Args:
        main_bg: Path to the image file on disk.
        main_bg_ext: Subtype placed in the data URI's MIME type
            (``image/<main_bg_ext>``). Defaults to ``"jpg"``, the value the
            function previously hard-coded.

    Raises:
        OSError: if *main_bg* cannot be opened or read.
    """
    # Read through a context manager so the handle is closed deterministically;
    # the previous inline ``open(...).read()`` left closing to the GC.
    with open(main_bg, "rb") as image_file:
        encoded = base64.b64encode(image_file.read()).decode()
    st.markdown(
        f"""
    <style>
    .stApp {{
        background: url(data:image/{main_bg_ext};base64,{encoded});
        background-size: cover
    }}
    </style>
    """,
        unsafe_allow_html=True,
    )
# Paint the page background, then collect the user's free-text query and how
# many recommendations to show (1-10, default 3).
set_background("images/p1.jpg")
topic = st.text_input(
    'What kind of paper would you wish to be recommended?',
    value='I want to read a paper on Bayesian Optimization!',
)
number = st.number_input(
    'Show me these many papers.',
    min_value=1,
    max_value=10,
    step=1,
    value=3,
)
# Translation table for ``process_text``: newlines become spaces and the
# characters "(", ")" and "!" are deleted. Built once at import time instead
# of rebuilding a dict and recompiling a regex on every call.
_PROCESS_TEXT_TABLE = str.maketrans({"\n": " ", "(": None, ")": None, "!": None})


def process_text(text):
    """Normalise free text before it is embedded.

    Replaces each newline with a space, removes every "(", ")" and "!",
    and lower-cases the result.

    Args:
        text: The raw input string.

    Returns:
        The cleaned, lower-cased string.
    """
    # str.translate performs all single-character substitutions in one C-level pass.
    return text.translate(_PROCESS_TEXT_TABLE).lower()
def get_cosine_similarity(feature_vec_1, feature_vec_2):
    """Return the cosine similarity of two feature vectors as a scalar.

    sklearn's ``cosine_similarity`` works on ``(n_samples, n_features)``
    matrices, so each vector is reshaped into a single row first and the lone
    entry of the resulting 1x1 matrix is returned.
    """
    row_a = feature_vec_1.reshape(1, -1)
    row_b = feature_vec_2.reshape(1, -1)
    similarity_matrix = cosine_similarity(row_a, row_b)
    return similarity_matrix[0, 0]
# Streamlit re-executes this whole script on every widget interaction, so an
# uncached loader rebuilds the SentenceTransformer on each "GO!" click.
# Prefer ``st.cache_resource`` (newer Streamlit), fall back to the older
# ``st.experimental_singleton``, and finally to no caching at all.
_cache_model = (
    getattr(st, "cache_resource", None)
    or getattr(st, "experimental_singleton", None)
    or (lambda func: func)
)


@_cache_model
def get_model():
    """Load the ``paraphrase-MiniLM-L6-v2`` SentenceTransformer.

    Returns:
        The SentenceTransformer model. It is created with ``device='cuda'``
        when ``torch.cuda.is_available()``; otherwise ``device=None`` is passed
        and SentenceTransformer chooses the device itself. When a Streamlit
        cache decorator is available, a single instance is reused across reruns.
    """
    device = "cuda" if torch.cuda.is_available() else None
    return SentenceTransformer("paraphrase-MiniLM-L6-v2", device=device)
if st.button("GO!"):
    # Clean the query the same way every time, then embed it.
    prompt = process_text(topic)
    model = get_model()
    prompt_embedded = model.encode(prompt)

    # NOTE(review): unpickling can execute arbitrary code. This is acceptable only
    # because the file ships with the app; never load user-supplied pickles here.
    # Duplicate titles are dropped so the same paper cannot be listed twice.
    df_embed = pd.read_pickle('data/embeddings_pkl.pkl').drop_duplicates(subset=['titles'])
    # Presumably one abstract embedding vector per row; confirm against the
    # script that produced the pickle.
    df_embed["similarity_scores"] = df_embed["abstracts_embeddings"].apply(
        lambda embedding: get_cosine_similarity(embedding, prompt_embedded)
    )
    # Show exactly as many papers as requested (the input allows 1-10).
    # A trailing ``.head(5)`` previously capped the list at 5 no matter what.
    top_n = df_embed.nlargest(int(number), 'similarity_scores')["titles"].to_list()

    st.text(" ")
    st.subheader('Have a look at the following: :sunglasses:')
    for rec_title in top_n:
        st.markdown("-> " + rec_title)