import streamlit as st
import pandas as pd
import numpy as np
import pickle
from ast import literal_eval
from collections import Counter

import plotly.express as px
from huggingface_hub import hf_hub_download
from langdetect import detect
from sentence_transformers import SentenceTransformer, util

# sidebar
with st.sidebar:
    st.header("Examples:")
    st.markdown("This search engine finds content in Medium articles.")

# main content
st.header("Semantic Search Engine on [Medium](https://medium.com/) articles")
st.markdown("This is a small demo project of a semantic search engine over a "
            "dataset of ~190k Medium articles.")

st_placeholder_loading = st.empty()
st_placeholder_loading.text("Loading medium articles data...")

@st.cache(allow_output_mutation=True)
def load_data():
    # Download the articles metadata and the precomputed corpus embeddings from the
    # Hugging Face Hub, and load the sentence-transformers model used to embed queries.
    df_articles = pd.read_csv(
        hf_hub_download("fabiochiu/medium-articles", repo_type="dataset",
                        filename="medium_articles_no_text.csv"))
    corpus_embeddings = pickle.load(
        open(hf_hub_download("fabiochiu/medium-articles", repo_type="dataset",
                             filename="medium_articles_embeddings.pickle"), "rb"))
    embedder = SentenceTransformer("all-MiniLM-L6-v2")
    return df_articles, corpus_embeddings, embedder

df_articles, corpus_embeddings, embedder = load_data()
st_placeholder_loading.empty()

n_top_tags = 20

@st.cache()
def load_chart_top_tags():
    # Count tag occurrences and chart the n_top_tags most frequent ones
    all_tags = [tag for tags_list in df_articles["tags"]
                for tag in literal_eval(tags_list)]
    d_tags_counter = Counter(all_tags)
    tags, frequencies = list(zip(*d_tags_counter.most_common(n=n_top_tags)))
    fig = px.bar(x=tags, y=frequencies)
    fig.update_xaxes(title="tags")
    fig.update_yaxes(title="frequencies")
    return fig

fig_top_tags = load_chart_top_tags()
st.plotly_chart(fig_top_tags)

st_query = st.text_input("Write your query here", max_chars=100)

def on_click_search():
    if st_query != "":
        # Embed the query and retrieve the most similar articles.
        query_embedding = embedder.encode(st_query, convert_to_tensor=True)
        top_k = 10
        # Retrieve extra hits so that non-English or very short titles can be filtered out.
        hits = util.semantic_search(query_embedding, corpus_embeddings,
                                    top_k=top_k * 2)[0]

        article_dicts = []
        for hit in hits:
            score = hit["score"]
            article_row = df_articles.iloc[hit["corpus_id"]]
            try:
                detected_lang = detect(article_row["title"])
            except Exception:
                # langdetect can fail on very short or ambiguous titles
                detected_lang = ""
            if detected_lang == "en" and len(article_row["title"]) >= 10:
                article_dicts.append({
                    "title": article_row["title"],
                    "url": article_row["url"],
                    "score": score
                })
            if len(article_dicts) >= top_k:
                break

        st.session_state.article_dicts = article_dicts
        st.session_state.empty_query = False
    else:
        st.session_state.article_dicts = []
        st.session_state.empty_query = True

st.button("Search", on_click=on_click_search)

# Pressing Enter in the text input also triggers a search.
if st_query != "":
    st.session_state.empty_query = False
    on_click_search()
else:
    st.session_state.empty_query = True

if not st.session_state.empty_query:
    st.markdown("### Results")
    st.markdown("*Scores between parentheses represent the similarity between "
                "the article and the query.*")
    for article_dict in st.session_state.article_dicts:
        st.markdown(f"- [{article_dict['title'].capitalize()}]"
                    f"({article_dict['url']}) ({article_dict['score']:.2f})")
elif st.session_state.empty_query and "article_dicts" in st.session_state:
    st.markdown("Please write a query and then press the search button.")
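
# --- Not used by the app above ---
# A minimal sketch of how the precomputed "medium_articles_embeddings.pickle"
# file could plausibly have been built. The app only downloads the pickle, so
# the column that was actually embedded (title, title + text, ...) and the
# exact preprocessing are assumptions here; the model name is the only part
# known to match, since the same model embeds the queries at search time.
def build_corpus_embeddings_sketch():
    # Hypothetical helper, never called by the app.
    df = pd.read_csv(
        hf_hub_download("fabiochiu/medium-articles", repo_type="dataset",
                        filename="medium_articles_no_text.csv"))
    model = SentenceTransformer("all-MiniLM-L6-v2")
    # Assumption: the article titles are what get embedded.
    embeddings = model.encode(df["title"].tolist(),
                              convert_to_tensor=True,
                              show_progress_bar=True)
    with open("medium_articles_embeddings.pickle", "wb") as f:
        pickle.dump(embeddings, f)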