Spaces:
Build error
Build error
import ast
import pickle
from collections import Counter

import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st
from huggingface_hub import hf_hub_download
from langdetect import DetectorFactory, detect
from sentence_transformers import SentenceTransformer, util
# Sidebar ------------------------------------------------------------------
st.sidebar.header("Examples:")
st.sidebar.markdown("This search finds content in Medium .")

# Main page ----------------------------------------------------------------
st.header("Semantic Search Engine on [Medium](https://medium.com/) articles")
st.markdown(
    "This is a small demo project of a semantic search engine"
    " over a dataset of ~190k Medium articles."
)
# Placeholder that shows a loading message while the data is being fetched.
st_placeholder_loading = st.empty()
st_placeholder_loading.text("Loading medium articles data...")
def load_data():
    """Download the Medium dataset and load everything the search needs.

    Returns:
        A tuple ``(df_articles, corpus_embeddings, embedder)``:

        - ``df_articles``: DataFrame read from ``medium_articles_no_text.csv``
          in the ``fabiochiu/medium-articles`` Hugging Face dataset repo.
        - ``corpus_embeddings``: the object unpickled from
          ``medium_articles_embeddings.pickle`` in the same repo. It is used
          as the corpus for ``util.semantic_search``, so presumably it holds
          one embedding per row of ``df_articles`` (TODO confirm alignment).
        - ``embedder``: the ``all-MiniLM-L6-v2`` SentenceTransformer used to
          embed user queries.
    """
    repo_id = "fabiochiu/medium-articles"

    csv_path = hf_hub_download(repo_id, repo_type="dataset", filename="medium_articles_no_text.csv")
    df_articles = pd.read_csv(csv_path)

    embeddings_path = hf_hub_download(repo_id, repo_type="dataset", filename="medium_articles_embeddings.pickle")
    # A context manager closes the file handle as soon as unpickling is done.
    # SECURITY NOTE: unpickling can execute arbitrary code; this is only
    # acceptable because the file comes from a fixed, trusted dataset repo.
    with open(embeddings_path, "rb") as embeddings_file:
        corpus_embeddings = pickle.load(embeddings_file)

    embedder = SentenceTransformer('all-MiniLM-L6-v2')

    # NOTE(review): this function is not cached. Streamlit re-executes the
    # script on user interactions, so the CSV, embeddings and model are
    # reloaded each time. Consider st.cache_resource (Streamlit >= 1.18) or
    # the legacy st.cache equivalent, depending on the installed version.
    return df_articles, corpus_embeddings, embedder
# How many of the most frequent tags the tag-frequency chart shows.
n_top_tags = 20

# Load articles, their embeddings and the query encoder, then drop the
# "loading" message now that the data is available.
df_articles, corpus_embeddings, embedder = load_data()
st_placeholder_loading.empty()
def load_chart_top_tags():
    """Build a bar chart of the ``n_top_tags`` most frequent article tags.

    Each value of ``df_articles["tags"]`` is parsed as a Python list literal,
    presumably of tag strings such as ``"['python', 'data']"`` (verify against
    the CSV). Non-string cells, such as NaN for a missing value, are skipped.

    Returns:
        A Plotly Express bar figure with tags on the x axis and their
        occurrence counts on the y axis. When no tags are found, the figure
        has no bars.
    """
    d_tags_counter = Counter()
    for tags_list in df_articles["tags"]:
        if not isinstance(tags_list, str):
            continue  # missing value (e.g. NaN) -> nothing to count
        # ast.literal_eval only accepts Python literals, so unlike eval it
        # cannot execute code hidden in a malformed or malicious cell.
        d_tags_counter.update(ast.literal_eval(tags_list))

    top_tags = d_tags_counter.most_common(n=n_top_tags)
    # Split the (tag, count) pairs into two lists. Unlike
    # `tags, frequencies = zip(*top_tags)`, this also works when there are no tags.
    tags = [tag for tag, _ in top_tags]
    frequencies = [count for _, count in top_tags]

    fig = px.bar(x=tags, y=frequencies)
    fig.update_xaxes(title="tags")
    fig.update_yaxes(title="frequencies")
    return fig
# Tag-frequency bar chart.
# NOTE(review): fig_top_tags is never rendered in the visible code (no
# st.plotly_chart call); confirm whether it should be shown or dropped.
fig_top_tags = load_chart_top_tags()

# Free-text search box; its value is read by the search logic.
st_query = st.text_input(label="Write your query here", max_chars=100)
def on_click_search():
    """Run a semantic search for ``st_query`` and store the results in session state.

    The query is embedded with ``embedder`` and the ``2 * top_k`` nearest
    entries of ``corpus_embeddings`` are retrieved. From those, at most
    ``top_k`` articles are kept whose title is a string of at least 10
    characters that langdetect classifies as English.

    Side effects on ``st.session_state``:
        - ``article_dicts``: list of ``{"title", "url", "score"}`` dicts. It is
          empty when the query is blank.
        - ``empty_query``: ``True`` when the query is empty or whitespace-only,
          otherwise ``False``.
    """
    query = (st_query or "").strip()
    if not query:
        st.session_state.article_dicts = []
        st.session_state.empty_query = True
        return

    # langdetect is non-deterministic by default. Fixing the seed makes the
    # same query always keep the same set of articles.
    DetectorFactory.seed = 0

    top_k = 10
    query_embedding = embedder.encode(query, convert_to_tensor=True)
    # Over-fetch so that enough hits survive the title filters below.
    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k * 2)[0]

    article_dicts = []
    for hit in hits:
        article_row = df_articles.iloc[hit['corpus_id']]
        title = article_row["title"]
        # Cheap checks first: skip missing (non-string) and short titles
        # before running the comparatively costly language detection.
        if not isinstance(title, str) or len(title) < 10:
            continue
        try:
            detected_lang = detect(title)
        except Exception:
            # langdetect raises when it cannot classify the text (e.g. a
            # title with no letters); treat that title as non-English.
            continue
        if detected_lang != "en":
            continue
        article_dicts.append({
            "title": title,
            "url": article_row['url'],
            "score": hit['score'],
        })
        if len(article_dicts) >= top_k:
            break

    st.session_state.article_dicts = article_dicts
    st.session_state.empty_query = False
# The search runs on every script run with a non-empty query, not only after
# a button click, so the button only needs to trigger a rerun. Do not also
# pass on_click=on_click_search: each click would then run the embedding and
# semantic search twice.
search_clicked = st.button("Search")
if st_query != "":
    st.session_state.empty_query = False
    on_click_search()
else:
    st.session_state.empty_query = True
    if search_clicked:
        # Clicking with an empty query clears any previous results, which
        # makes the "Please write a query" hint below appear.
        st.session_state.article_dicts = []

if not st.session_state.empty_query:
    st.markdown("### Results")
    st.markdown("*Scores between parentheses represent the similarity between the article and the query.*")
    for article_dict in st.session_state.article_dicts:
        st.markdown(f"""- [{article_dict['title'].capitalize()}]({article_dict['url']}) ({article_dict['score']:.2f})""")
elif "article_dicts" in st.session_state:
    # Empty query, but a search was attempted earlier in this session.
    st.markdown("Please write a query and then press the search button.")