File size: 2,423 Bytes
e18db08
 
b493a01
4709571
28fedac
4709571
 
b493a01
4709571
 
 
 
b493a01
4709571
 
b493a01
4709571
 
b493a01
 
 
462c96e
b493a01
 
 
c956188
 
 
 
4709571
 
 
 
 
 
e18db08
 
 
 
 
b493a01
4709571
 
 
 
b493a01
e18db08
 
 
 
 
b493a01
4709571
 
 
 
b493a01
 
 
4709571
28fedac
 
c956188
 
 
 
 
 
28fedac
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import difflib

from collections import Counter
import streamlit as st
import pandas as pd
import srsly


def search(query):
    """Store every grant carrying the tag ``query`` in the session state.

    Intended as a Streamlit ``on_click`` callback. A grant matches when
    ``query`` is an element of its ``tags`` list (plain ``in`` membership,
    i.e. exact equality with one of the tags). Each match is reduced to its
    ``title`` and ``tags`` and the list (possibly empty) replaces
    ``st.session_state["results"]``.

    Args:
        query: The MeSH tag to look for.
    """
    st.session_state["results"] = [
        {"title": g["title"], "tags": g["tags"]}
        for g in grants
        if query in g["tags"]
    ]


# Page title plus an informational sidebar with data sources and the
# user-tunable display parameters.
st.header("Search 🔎 grants using MeSH 🔖")
st.sidebar.header("Information ℹ")
st.sidebar.write(
    "A complete list of MeSH tags can be found here https://meshb.nlm.nih.gov/treeView"
)
# The visible link text matches the URL the link actually points to.
st.sidebar.write(
    "The grants data can be found at "
    "[https://data.threesixtygiving.org/](https://data.threesixtygiving.org/). "
    "They are published under a "
    "[CC BY 4.0](https://creativecommons.org/licenses/by/4.0/) license."
)
st.sidebar.write(
    "The model used to tag grants is https://huggingface.co/Wellcome/WellcomeBertMesh"
)
st.sidebar.header("Parameters")
# Upper bound on how many result rows the table at the bottom shows.
nb_results = st.sidebar.slider(
    "Number of results to display", value=20, min_value=1, max_value=100
)

# Streamlit re-executes this script on every interaction, so the parsed
# JSONL file is stored in session state and only read on the first run.
if "grants" not in st.session_state:
    st.session_state["grants"] = list(srsly.read_jsonl("tagged_grants.jsonl"))

grants = st.session_state["grants"]

# Vocabulary of every distinct tag across all grants, used for fuzzy
# suggestions. Sorted so the order is stable between sessions (iteration
# order of a bare set of strings depends on hash randomization).
if "tags" not in st.session_state:
    st.session_state["tags"] = sorted(
        {tag for grant in grants for tag in grant["tags"]}
    )

tags = st.session_state["tags"]

# Keyed so the button callback can read the committed value from session state.
query = st.text_input("", value="Malaria", key="query")
# Read the query inside the callback instead of binding it through ``kwargs``:
# callback arguments are captured when the button is rendered, so a value typed
# right before clicking could otherwise reach ``search`` stale.
st.button("Search 🔎", on_click=lambda: search(st.session_state["query"]))

# Results section: rendered only once a search callback has populated
# st.session_state["results"].
if "results" in st.session_state:
    st.caption("Related MeSH terms")

    results = st.session_state["results"]

    # Suggestion buttons are laid out in a fixed grid; request exactly as many
    # tags as there are cells so none are computed only to be dropped.
    n_cols, n_rows = 5, 3
    n_suggestions = n_cols * n_rows

    if results:
        # Tags occurring most often across the matching grants.
        retrieved_tags = [tag for res in results for tag in res["tags"]]
        most_common_tags = [
            tag for tag, _ in Counter(retrieved_tags).most_common(n_suggestions)
        ]
    else:
        # No exact match: fall back to fuzzy suggestions for the typed query.
        most_common_tags = difflib.get_close_matches(query, tags, n=n_suggestions)

    columns = st.columns(n_cols)
    for row_i in range(n_rows):
        for col_i, col in enumerate(columns):
            with col:
                # Fill the grid row by row, left to right.
                tag_i = row_i * n_cols + col_i
                if tag_i < len(most_common_tags):
                    tag = most_common_tags[tag_i]
                    st.button(tag, on_click=search, kwargs={"query": tag})

    # Report the number of rows actually shown, which is smaller than the
    # slider value when there are not enough results.
    nb_displayed = min(nb_results, len(results))
    st.caption(f"Found {len(results)}. Displaying {nb_displayed}")
    st.download_button(
        "Download results",
        # Explicit columns keep a header row even when there are no results;
        # index=False drops pandas' synthetic 0..n-1 row index from the CSV.
        data=pd.DataFrame(results, columns=["title", "tags"]).to_csv(index=False),
        file_name="results.csv",
        mime="text/csv",
    )
    st.table(results[:nb_results])