"""Streamlit app for searching tagged grants by MeSH term.

Grants are read once per session from ``tagged_grants.jsonl`` and kept in
``st.session_state``. A search stores the matching grants under the
``"results"`` session key; the page then shows related MeSH terms as
clickable buttons, a CSV download of the matches and a table of the first
``nb_results`` matches.
"""

import difflib
from collections import Counter

import pandas as pd
import srsly
import streamlit as st

# Maximum number of related-tag buttons offered after a search, and the
# number of columns they are laid out in.
MAX_RELATED_TAGS = 20
TAG_COLUMNS = 5


def search(query: str) -> None:
    """Store the grants whose tags contain ``query`` in session state.

    Used as the ``on_click`` callback of the search button and of every
    related-tag button. The matches are written to
    ``st.session_state["results"]`` as dicts holding only the ``"title"``
    and ``"tags"`` of each grant.

    Args:
        query: Value tested with ``in`` against each grant's ``"tags"``.
    """
    # Read the grants from session state instead of a module-level name so
    # the callback does not depend on a global assigned later in the script.
    st.session_state["results"] = [
        {"title": grant["title"], "tags": grant["tags"]}
        for grant in st.session_state["grants"]
        if query in grant["tags"]
    ]


st.header("Search 🔎 grants using MeSH 🔖")

st.sidebar.header("Information ℹ")
st.sidebar.write(
    "A complete list of MeSH tags can be found here https://meshb.nlm.nih.gov/treeView"
)
st.sidebar.write(
    "The grants data can be found at "
    "[https://www.threesixtygiving.org/](https://data.threesixtygiving.org/). "
    "They are published under a "
    "[CC BY 4.0](https://creativecommons.org/licenses/by/4.0/) license."
)
st.sidebar.write(
    "The model used to tag grants is https://huggingface.co/Wellcome/WellcomeBertMesh"
)
st.sidebar.header("Parameters")
nb_results = st.sidebar.slider(
    "Number of results to display", value=20, min_value=1, max_value=100
)

# Load the grants only once per session; later reruns reuse the cached list.
if "grants" not in st.session_state:
    st.session_state["grants"] = list(srsly.read_jsonl("tagged_grants.jsonl"))
grants = st.session_state["grants"]

# Distinct tags across all grants, used for "did you mean" suggestions.
if "tags" not in st.session_state:
    st.session_state["tags"] = list(
        {tag for grant in grants for tag in grant["tags"]}
    )
tags = st.session_state["tags"]

query = st.text_input("", value="Malaria")
st.button("Search 🔎", on_click=search, kwargs={"query": query})

if "results" in st.session_state:
    results = st.session_state["results"]

    st.caption("Related MeSH terms")
    if results:
        # Suggest the tags that co-occur most often in the matched grants.
        retrieved_tags = [tag for res in results for tag in res["tags"]]
        most_common_tags = [
            tag for tag, _ in Counter(retrieved_tags).most_common(MAX_RELATED_TAGS)
        ]
    else:
        # Nothing matched: fall back to tags spelled similarly to the query.
        most_common_tags = difflib.get_close_matches(
            query, tags, n=MAX_RELATED_TAGS
        )

    # Fill the grid row by row. Every computed tag gets a button; the grid
    # was previously capped at 3 rows x 5 columns, which dropped tags 16-20.
    columns = st.columns(TAG_COLUMNS)
    for tag_i, tag in enumerate(most_common_tags):
        with columns[tag_i % TAG_COLUMNS]:
            st.button(tag, on_click=search, kwargs={"query": tag})

    # Report the number of rows actually shown, which is smaller than the
    # slider value when fewer results were found.
    nb_displayed = min(nb_results, len(results))
    st.caption(f"Found {len(results)}. Displaying {nb_displayed}")
    st.download_button(
        "Download results",
        data=pd.DataFrame(results).to_csv(),
        file_name="results.csv",
        mime="text/csv",
    )
    st.table(results[:nb_results])