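# Streamlit demo: semantic search engine over a dataset of ~190k Medium articles.
# Article metadata and precomputed sentence embeddings are downloaded from the
# Hugging Face Hub; user queries are embedded with a SentenceTransformer model
# and matched against the corpus with util.semantic_search.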
import streamlit as st
import pandas as pd
import numpy as np
import pickle
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer, util
from langdetect import detect
import plotly.express as px
from collections import Counter
from ast import literal_eval
# sidebar
with st.sidebar:
    st.header("Examples:")
    st.markdown("This search finds content in Medium articles.")
# main content
st.header("Semantic Search Engine on [Medium](https://medium.com/) articles")
st.markdown("This is a small demo project of a semantic search engine over a dataset of ~190k Medium articles.")
st_placeholder_loading = st.empty()
st_placeholder_loading.text('Loading medium articles data...')
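# Cache the heavy downloads and the model so they are loaded only once,
# not on every script rerun.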
@st.cache(allow_output_mutation=True)
def load_data():
    # Download the article metadata and the precomputed corpus embeddings from
    # the Hugging Face Hub, and load the sentence-embedding model.
    df_articles = pd.read_csv(hf_hub_download("fabiochiu/medium-articles", repo_type="dataset", filename="medium_articles_no_text.csv"))
    with open(hf_hub_download("fabiochiu/medium-articles", repo_type="dataset", filename="medium_articles_embeddings.pickle"), "rb") as f:
        corpus_embeddings = pickle.load(f)
    embedder = SentenceTransformer('all-MiniLM-L6-v2')
    return df_articles, corpus_embeddings, embedder
df_articles, corpus_embeddings, embedder = load_data()
st_placeholder_loading.empty()
n_top_tags = 20

@st.cache()
def load_chart_top_tags():
    # Bar chart with the occurrences of the n_top_tags most frequent tags.
    all_tags = [tag for tags_list in df_articles["tags"] for tag in literal_eval(tags_list)]
    d_tags_counter = Counter(all_tags)
    tags, frequencies = list(zip(*d_tags_counter.most_common(n=n_top_tags)))
    fig = px.bar(x=tags, y=frequencies)
    fig.update_xaxes(title="tags")
    fig.update_yaxes(title="frequencies")
    return fig
fig_top_tags = load_chart_top_tags()
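# Assumption: the top-tags figure is meant to be shown to the user; render it here.
st.plotly_chart(fig_top_tags)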
st_query = st.text_input("Write your query here", max_chars=100)
def on_click_search():
    if st_query != "":
        # Embed the query and retrieve the most similar articles by cosine
        # similarity. Fetch 2 * top_k hits so that non-English or very short
        # titles can be filtered out while still returning up to top_k results.
        query_embedding = embedder.encode(st_query, convert_to_tensor=True)
        top_k = 10
        hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k*2)[0]
        article_dicts = []
        for hit in hits:
            score = hit['score']
            article_row = df_articles.iloc[hit['corpus_id']]
            try:
                detected_lang = detect(article_row["title"])
            except Exception:
                detected_lang = ""
            if detected_lang == "en" and len(article_row["title"]) >= 10:
                article_dicts.append({
                    "title": article_row['title'],
                    "url": article_row['url'],
                    "score": score
                })
            if len(article_dicts) >= top_k:
                break
        st.session_state.article_dicts = article_dicts
        st.session_state.empty_query = False
    else:
        st.session_state.article_dicts = []
        st.session_state.empty_query = True
st.button("Search", on_click=on_click_search)
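# The search also runs on every rerun while a query is present (e.g. after
# pressing Enter in the text input), so clicking the button is optional.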
if st_query != "":
st.session_state.empty_query = False
on_click_search()
else:
st.session_state.empty_query = True
if not st.session_state.empty_query:
    st.markdown("### Results")
    st.markdown("*Scores between parentheses represent the similarity between the article and the query.*")
    for article_dict in st.session_state.article_dicts:
        st.markdown(f"""- [{article_dict['title'].capitalize()}]({article_dict['url']}) ({article_dict['score']:.2f})""")
elif st.session_state.empty_query and "article_dicts" in st.session_state:
    st.markdown("Please write a query and then press the search button.")