SearchMesh / app.py
nsorros's picture
Return related terms when no results are found
e18db08
raw history blame
No virus
2.29 kB
import difflib
from collections import Counter
import streamlit as st
import pandas as pd
import srsly
def search(query):
    """Store the grants whose tags contain ``query`` in the session state.

    Each grant in the module-level ``grants`` collection is kept when
    ``query in grant["tags"]`` holds. Matches are reduced to their
    ``title`` and ``tags`` fields and written to
    ``st.session_state["results"]``, replacing any earlier results.
    Nothing is returned; the function is used as a button callback.
    """
    st.session_state["results"] = [
        {"title": grant["title"], "tags": grant["tags"]}
        for grant in grants
        if query in grant["tags"]
    ]
# Page title. The emoji are literal UTF-8 characters (magnifying glass and
# bookmark) rather than mis-decoded byte sequences.
st.header("Search 🔎 grants using MeSH 🔖")

# Sidebar, part 1: pointers to the tag vocabulary, the grants data and the
# tagging model.
st.sidebar.header("Information ℹ")
st.sidebar.write(
    "A complete list of MeSH tags can be found here https://meshb.nlm.nih.gov/treeView"
)
st.sidebar.write("The grants data can be found https://www.threesixtygiving.org/")
st.sidebar.write(
    "The model used to tag grants is https://huggingface.co/Wellcome/WellcomeBertMesh"
)

# Sidebar, part 2: how many result rows the table below should show.
st.sidebar.header("Parameters")
nb_results = st.sidebar.slider(
    "Number of results to display", value=20, min_value=1, max_value=100
)
# Read the tagged grants from disk only on the first run of a session;
# subsequent reruns reuse the list kept in the session state.
if "grants" not in st.session_state:
    st.session_state["grants"] = list(srsly.read_jsonl("tagged_grants.jsonl"))
grants = st.session_state["grants"]

# Build the deduplicated list of all tags once per session as well.
if "tags" not in st.session_state:
    unique_tags = {tag for grant in grants for tag in grant["tags"]}
    st.session_state["tags"] = list(unique_tags)
tags = st.session_state["tags"]
# Query box (pre-filled with "Malaria") and the button that hands the current
# text to the `search` callback. The button label uses the real magnifying
# glass emoji instead of its mis-decoded byte sequence.
query = st.text_input("", value="Malaria")
st.button("Search 🔎", on_click=search, kwargs={"query": query})
# Everything below is rendered only after a search has stored its results
# in the session state.
if "results" in st.session_state:
    results = st.session_state["results"]

    st.caption("Related MeSH terms")
    if results:
        # Suggest the tags that occur most often across the matched grants.
        retrieved_tags = [tag for res in results for tag in res["tags"]]
        most_common_tags = [tag for tag, _ in Counter(retrieved_tags).most_common(20)]
    else:
        # Nothing matched: fall back to known tags that resemble the query.
        most_common_tags = difflib.get_close_matches(query, tags, n=20)

    # Place every suggested tag in a five-column grid, filling it row by row.
    # Iterating over the tags themselves (instead of a fixed 3x5 grid) makes
    # sure none of the up-to-20 computed suggestions is silently dropped.
    columns = st.columns(5)
    for tag_i, tag in enumerate(most_common_tags):
        with columns[tag_i % len(columns)]:
            # Clicking a suggestion runs a new search for that tag.
            st.button(tag, on_click=search, kwargs={"query": tag})

    # The table shows at most `nb_results` rows, so report the number that is
    # actually displayed rather than the slider value.
    nb_displayed = min(nb_results, len(results))
    st.caption(f"Found {len(results)}. Displaying {nb_displayed}")
    st.download_button(
        "Download results",
        data=pd.DataFrame(results).to_csv(),
        file_name="results.csv",
        mime="text/csv",
    )
    st.table(results[:nb_results])