ayushnoori commited on
Commit
6efe11e
·
1 Parent(s): f18a5e1

Update pfp

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ media/pfp/*.png filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -121,7 +121,8 @@ def check_password():
121
  # Retrieve and store user name and team
122
  st.session_state["name"] = user_db.loc[user_db.username == st.session_state["username"], "name"].values[0]
123
  st.session_state["team"] = user_db.loc[user_db.username == st.session_state["username"], "team"].values[0]
124
- st.session_state["profile_pic"] = user_db.loc[user_db.username == st.session_state["username"], "profile_pic"].values[0]
 
125
 
126
  # Don't store the username or password
127
  del st.session_state["password"]
 
121
  # Retrieve and store user name and team
122
  st.session_state["name"] = user_db.loc[user_db.username == st.session_state["username"], "name"].values[0]
123
  st.session_state["team"] = user_db.loc[user_db.username == st.session_state["username"], "team"].values[0]
124
+ # st.session_state["profile_pic"] = user_db.loc[user_db.username == st.session_state["username"], "profile_pic"].values[0]
125
+ st.session_state["profile_pic"] = st.session_state["username"]
126
 
127
  # Don't store the username or password
128
  del st.session_state["password"]
media/pfp/anoori.png ADDED

Git LFS Details

  • SHA256: 56f2cd51f6496ff1e43f0ce3fb63145a442772b16e3d456bba06cf86d78671cf
  • Pointer size: 132 Bytes
  • Size of remote file: 1.53 MB
media/pfp/bbudnik.png ADDED

Git LFS Details

  • SHA256: 9b6822130cf0db1a934dbcfa55ca0a2d0787f41a18ef741cae9d65106d3e0db7
  • Pointer size: 132 Bytes
  • Size of remote file: 1.56 MB
media/pfp/iarango.png ADDED

Git LFS Details

  • SHA256: 467539db9c9fd915c72d161c7e4b09d083ee950efb5415a4a81bb77235a8fea0
  • Pointer size: 132 Bytes
  • Size of remote file: 1.17 MB
media/pfp/jtam.png ADDED

Git LFS Details

  • SHA256: 9620fe17a9374b0eaa95c48e7d6f8cf7538d29b62b795e4d8a1dba804c6a274f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.19 MB
media/pfp/kmeyer.png ADDED

Git LFS Details

  • SHA256: 1269def5de603b045dad6410561544e7d42fa85b66321634a8d9979ef3c582ac
  • Pointer size: 132 Bytes
  • Size of remote file: 1.33 MB
media/pfp/mzitnik.png ADDED

Git LFS Details

  • SHA256: b514858118909ce8004028a1f87f3e7a259415d730d34371830e884a1343da2f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.2 MB
media/pfp/ndagan.png ADDED

Git LFS Details

  • SHA256: d7d169b3a4cceca7bcb829ae02ea5ce912c94b5b66006343f1d5bfdc7296ce79
  • Pointer size: 132 Bytes
  • Size of remote file: 1.07 MB
media/pfp/nliu.png ADDED

Git LFS Details

  • SHA256: 2e71d37c3512b91f782871e612c26ad0034f86536dfb97c9e27fd8979ba49f85
  • Pointer size: 132 Bytes
  • Size of remote file: 1.34 MB
media/pfp/rbalicer.png ADDED

Git LFS Details

  • SHA256: c1b68144f018798483fb7485d9fa46350e6274bc9ece3d5548ec33226ca2690d
  • Pointer size: 131 Bytes
  • Size of remote file: 580 kB
media/pfp/vkhurana.png ADDED

Git LFS Details

  • SHA256: d063af30f22a6aeb948dad3afeaeba705f5cbed9a60f6d911b3ccb870adf5e31
  • Pointer size: 132 Bytes
  • Size of remote file: 1.04 MB
menu.py CHANGED
@@ -1,52 +1,21 @@
1
  # From https://docs.streamlit.io/develop/tutorials/multipage/st.page_link-nav
2
  import streamlit as st
3
-
4
 
5
  def authenticated_menu():
6
 
7
- st.markdown("""
8
- <style>
9
- .circle-image {
10
- width: 100px;
11
- height: 100px;
12
- border-radius: 50%;
13
- overflow: hidden;
14
- display: flex;
15
- justify-content: center;
16
- align-items: center;
17
- border: 2px solid black;
18
- margin: 0 auto 10px auto;
19
- }
20
-
21
- .circle-image img {
22
- width: 100%;
23
- height: 100%;
24
- object-fit: cover;
25
- }
26
-
27
- .username {
28
- font-size: 20px;
29
- font-weight: bold;
30
- text-align: center;
31
- margin-top: 0px;
32
- }
33
- </style>
34
- """, unsafe_allow_html=True)
35
-
36
- # Show the user's profile picture
37
- st.sidebar.markdown(f'<div class="circle-image"><img src="{st.session_state.profile_pic}" /></div>', unsafe_allow_html=True)
38
-
39
- # Show the user's name
40
- # st.sidebar.markdown(f"Logged in as {st.session_state.name}.")
41
- st.sidebar.markdown(f'<div class="username">{st.session_state.name}</div>', unsafe_allow_html=True)
42
  st.sidebar.markdown("---")
43
 
44
  # Show a navigation menu for authenticated users
45
  # st.sidebar.page_link("app.py", label="Switch Accounts", icon="🔒")
46
  st.sidebar.page_link("pages/about.py", label="About", icon="📖")
47
  st.sidebar.page_link("pages/input.py", label="Input", icon="💡")
48
- st.sidebar.page_link("pages/predict.py", label="Predict", icon="🔍")
49
- st.sidebar.page_link("pages/validate.py", label="Validate", icon="✅")
 
 
50
  # st.sidebar.page_link("pages/explore.py", label="Explore", icon="🔍")
51
  if st.session_state.role in ["admin"]:
52
  st.sidebar.page_link("pages/admin.py", label="Manage Users", icon="🔧")
@@ -54,7 +23,6 @@ def authenticated_menu():
54
  # Show the logout button
55
  st.sidebar.markdown("---")
56
  st.sidebar.button("Log Out", on_click=lambda: st.session_state.clear())
57
-
58
 
59
 
60
  def unauthenticated_menu():
 
1
  # From https://docs.streamlit.io/develop/tutorials/multipage/st.page_link-nav
2
  import streamlit as st
3
+ import project_config
4
 
5
  def authenticated_menu():
6
 
7
+ # Insert profile picture
8
+ st.sidebar.image(str(project_config.MEDIA_DIR / 'pfp' / f"{st.session_state.profile_pic}.png"), use_column_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  st.sidebar.markdown("---")
10
 
11
  # Show a navigation menu for authenticated users
12
  # st.sidebar.page_link("app.py", label="Switch Accounts", icon="🔒")
13
  st.sidebar.page_link("pages/about.py", label="About", icon="📖")
14
  st.sidebar.page_link("pages/input.py", label="Input", icon="💡")
15
+ st.sidebar.page_link("pages/predict.py", label="Predict", icon="🔍",
16
+ disabled=("query" not in st.session_state))
17
+ st.sidebar.page_link("pages/validate.py", label="Validate", icon="✅",
18
+ disabled=("query" not in st.session_state))
19
  # st.sidebar.page_link("pages/explore.py", label="Explore", icon="🔍")
20
  if st.session_state.role in ["admin"]:
21
  st.sidebar.page_link("pages/admin.py", label="Manage Users", icon="🔧")
 
23
  # Show the logout button
24
  st.sidebar.markdown("---")
25
  st.sidebar.button("Log Out", on_click=lambda: st.session_state.clear())
 
26
 
27
 
28
  def unauthenticated_menu():
pages/about.py CHANGED
@@ -18,3 +18,7 @@ st.markdown(f"Hello, {st.session_state.name}! Welcome to GRAVITY, a **GR**aph **
18
 
19
  # Subheader
20
  st.subheader("About GRAVITY", divider = "grey")
 
 
 
 
 
18
 
19
  # Subheader
20
  st.subheader("About GRAVITY", divider = "grey")
21
+
22
+ st.markdown("""
23
+ Knowledge graphs (KGs) are data structures that use network topology to represent relational information, including and especially in biology and medicine. Graph artificial intelligence (AI) models trained on these biomedical KGs can enable many important link prediction tasks, such as predicting disease progression, diagnosing genetic disorders, identifying therapeutic targets, and discovering new drugs. However, especially in biomedical settings, it is important for clinicians and scientists to evaluate whether KG-grounded AI models are safe and trustworthy, and whether the predictions of these models are biologically explainable. To address this challenge, we developed GRAVITY, an interactive user interface for graph-based explainable AI. GRAVITY enables human users to query and interpret KG-grounded AI models for biomedical link prediction tasks.
24
+ """)
pages/input.py CHANGED
@@ -10,6 +10,7 @@ from pathlib import Path
10
 
11
  # Custom and other imports
12
  import project_config
 
13
 
14
  # Redirect to app.py if not logged in, otherwise show the navigation menu
15
  menu_with_redirect()
@@ -25,8 +26,11 @@ st.subheader("Construct Query", divider = "red")
25
  # Checkbox to allow reverse edges
26
  allow_reverse_edges = st.checkbox("Reverse Edges", value = False)
27
 
 
 
 
28
  with st.spinner('Loading knowledge graph...'):
29
- kg_nodes = nodes = pd.read_csv(project_config.DATA_DIR / 'kg_nodes.csv', dtype = {'node_index': int}, low_memory = False)
30
  node_types = pd.read_csv(project_config.DATA_DIR / 'kg_node_types.csv')
31
  edge_types = pd.read_csv(project_config.DATA_DIR / 'kg_edge_types.csv')
32
 
@@ -64,6 +68,10 @@ if st.button("Submit Query"):
64
  # st.write(st.session_state.query)
65
  st.write("Query submitted.")
66
 
 
 
 
 
67
  st.subheader("Knowledge Graph", divider = "red")
68
  display_data = kg_nodes[['node_id', 'node_type', 'node_name', 'node_source']].copy()
69
  display_data = display_data.rename(columns = {'node_id': 'ID', 'node_type': 'Type', 'node_name': 'Name', 'node_source': 'Database'})
 
10
 
11
  # Custom and other imports
12
  import project_config
13
+ from utils import load_kg
14
 
15
  # Redirect to app.py if not logged in, otherwise show the navigation menu
16
  menu_with_redirect()
 
26
  # Checkbox to allow reverse edges
27
  allow_reverse_edges = st.checkbox("Reverse Edges", value = False)
28
 
29
+ # Load knowledge graph
30
+ kg_nodes = load_kg()
31
+
32
  with st.spinner('Loading knowledge graph...'):
33
+ # kg_nodes = nodes = pd.read_csv(project_config.DATA_DIR / 'kg_nodes.csv', dtype = {'node_index': int}, low_memory = False)
34
  node_types = pd.read_csv(project_config.DATA_DIR / 'kg_node_types.csv')
35
  edge_types = pd.read_csv(project_config.DATA_DIR / 'kg_edge_types.csv')
36
 
 
68
  # st.write(st.session_state.query)
69
  st.write("Query submitted.")
70
 
71
+ # Switch to the Predict page
72
+ st.switch_page("pages/predict.py")
73
+
74
+
75
  st.subheader("Knowledge Graph", divider = "red")
76
  display_data = kg_nodes[['node_id', 'node_type', 'node_name', 'node_source']].copy()
77
  display_data = display_data.rename(columns = {'node_id': 'ID', 'node_type': 'Type', 'node_name': 'Name', 'node_source': 'Database'})
pages/predict.py CHANGED
@@ -14,7 +14,7 @@ from huggingface_hub import hf_hub_download
14
 
15
  # Custom and other imports
16
  import project_config
17
- from utils import capitalize_after_slash
18
 
19
  # Redirect to app.py if not logged in, otherwise show the navigation menu
20
  menu_with_redirect()
@@ -28,13 +28,12 @@ st.image(str(project_config.MEDIA_DIR / 'predict_header.svg'), use_column_width=
28
  st.subheader(f"{capitalize_after_slash(st.session_state.query['target_node_type'])} Search", divider = "blue")
29
 
30
  # Print current query
31
- st.markdown(f"**Query:** {st.session_state.query['source_node']} ➡️ {st.session_state.query['relation']} ➡️ {st.session_state.query['target_node_type']}")
32
 
33
- with st.spinner('Loading knowledge graph...'):
34
- kg_nodes = nodes = pd.read_csv(project_config.DATA_DIR / 'kg_nodes.csv', dtype = {'node_index': int}, low_memory = False)
35
-
36
- # Get paths to embeddings, relation weights, and edge types
37
- with st.spinner('Downloading AI model...'):
38
  embed_path = hf_hub_download(repo_id="ayushnoori/galaxy",
39
  filename="2024_03_29_04_12_52_epoch=3-step=54291_embeddings.pt",
40
  token=st.secrets["HF_TOKEN"])
@@ -44,13 +43,23 @@ with st.spinner('Downloading AI model...'):
44
  edge_types_path = hf_hub_download(repo_id="ayushnoori/galaxy",
45
  filename="2024_03_29_04_12_52_epoch=3-step=54291_edge_types.pt",
46
  token=st.secrets["HF_TOKEN"])
 
47
 
48
- # Load embeddings, relation weights, and edge types
49
- with st.spinner('Loading AI model...'):
 
 
50
  embeddings = torch.load(embed_path)
51
  relation_weights = torch.load(relation_weights_path)
52
  edge_types = torch.load(edge_types_path)
53
 
 
 
 
 
 
 
 
54
  # # Print source node type
55
  # st.write(f"Source Node Type: {st.session_state.query['source_node_type']}")
56
 
@@ -78,7 +87,7 @@ with st.spinner('Computing predictions...'):
78
  edge_type_index = [i for i, etype in enumerate(edge_types) if etype == (source_node_type, relation, target_node_type)][0]
79
 
80
  # Get target nodes indices
81
- target_nodes = kg_nodes[kg_nodes.node_type == target_node_type]
82
  dst_indices = target_nodes.node_index.values
83
  src_indices = np.repeat(src_index, len(dst_indices))
84
 
@@ -126,7 +135,8 @@ with st.spinner('Computing predictions...'):
126
  display_data['Database'] = display_data.apply(lambda x: map_dbs[target_node_type](x['ID']), axis = 1)
127
 
128
  # Use multiselect to search for specific nodes
129
- selected_nodes = st.multiselect('Search for specific nodes.', display_data.Name)
 
130
 
131
  # Filter nodes
132
  if len(selected_nodes) > 0:
@@ -152,3 +162,6 @@ with st.spinner('Computing predictions...'):
152
  display_text = display_database)})
153
  else:
154
  st.dataframe(display_data.iloc[:top_k], use_container_width = True)
 
 
 
 
14
 
15
  # Custom and other imports
16
  import project_config
17
+ from utils import capitalize_after_slash, load_kg
18
 
19
  # Redirect to app.py if not logged in, otherwise show the navigation menu
20
  menu_with_redirect()
 
28
  st.subheader(f"{capitalize_after_slash(st.session_state.query['target_node_type'])} Search", divider = "blue")
29
 
30
  # Print current query
31
+ st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['relation'].replace('_', '-')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}")
32
 
33
+ @st.cache_data(show_spinner = 'Downloading AI model...')
34
+ def get_embeddings():
35
+ # Get paths to embeddings, relation weights, and edge types
36
+ # with st.spinner('Downloading AI model...'):
 
37
  embed_path = hf_hub_download(repo_id="ayushnoori/galaxy",
38
  filename="2024_03_29_04_12_52_epoch=3-step=54291_embeddings.pt",
39
  token=st.secrets["HF_TOKEN"])
 
43
  edge_types_path = hf_hub_download(repo_id="ayushnoori/galaxy",
44
  filename="2024_03_29_04_12_52_epoch=3-step=54291_edge_types.pt",
45
  token=st.secrets["HF_TOKEN"])
46
+ return embed_path, relation_weights_path, edge_types_path
47
 
48
+ @st.cache_data(show_spinner = 'Loading AI model...')
49
+ def load_embeddings(embed_path, relation_weights_path, edge_types_path):
50
+ # Load embeddings, relation weights, and edge types
51
+ # with st.spinner('Loading AI model...'):
52
  embeddings = torch.load(embed_path)
53
  relation_weights = torch.load(relation_weights_path)
54
  edge_types = torch.load(edge_types_path)
55
 
56
+ return embeddings, relation_weights, edge_types
57
+
58
+ # Load knowledge graph and embeddings
59
+ kg_nodes = load_kg()
60
+ embed_path, relation_weights_path, edge_types_path = get_embeddings()
61
+ embeddings, relation_weights, edge_types = load_embeddings(embed_path, relation_weights_path, edge_types_path)
62
+
63
  # # Print source node type
64
  # st.write(f"Source Node Type: {st.session_state.query['source_node_type']}")
65
 
 
87
  edge_type_index = [i for i, etype in enumerate(edge_types) if etype == (source_node_type, relation, target_node_type)][0]
88
 
89
  # Get target nodes indices
90
+ target_nodes = kg_nodes[kg_nodes.node_type == target_node_type].copy()
91
  dst_indices = target_nodes.node_index.values
92
  src_indices = np.repeat(src_index, len(dst_indices))
93
 
 
135
  display_data['Database'] = display_data.apply(lambda x: map_dbs[target_node_type](x['ID']), axis = 1)
136
 
137
  # Use multiselect to search for specific nodes
138
+ selected_nodes = st.multiselect(f"Search for specific {st.session_state.query['source_node_type'].replace('_', ' ')} nodes to determine their ranking.",
139
+ display_data.Name, placeholder = "Type to search...")
140
 
141
  # Filter nodes
142
  if len(selected_nodes) > 0:
 
162
  display_text = display_database)})
163
  else:
164
  st.dataframe(display_data.iloc[:top_k], use_container_width = True)
165
+
166
+ # Save to session state
167
+ st.session_state.predictions = display_data
pages/validate.py CHANGED
@@ -18,5 +18,9 @@ st.image(str(project_config.MEDIA_DIR / 'validate_header.svg'), use_column_width
18
 
19
  st.subheader("Validate Predictions", divider = "green")
20
 
 
 
 
 
21
  # Coming soon
22
  st.write("Coming soon...")
 
18
 
19
  st.subheader("Validate Predictions", divider = "green")
20
 
21
+ # Print current query
22
+ st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['relation'].replace('_', '-')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}")
23
+
24
+
25
  # Coming soon
26
  st.write("Coming soon...")
utils.py CHANGED
@@ -1,5 +1,13 @@
1
- import base64
2
  import streamlit as st
 
 
 
 
 
 
 
 
 
3
 
4
  def capitalize_after_slash(s):
5
  # Split the string by slashes first
@@ -52,4 +60,98 @@ def add_logo(png_file):
52
  st.markdown(
53
  logo_markup,
54
  unsafe_allow_html=True,
55
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import pandas as pd
3
+ import project_config
4
+ import base64
5
+
6
+ @st.cache_data(show_spinner = 'Loading knowledge graph...')
7
+ def load_kg():
8
+ # with st.spinner('Loading knowledge graph...'):
9
+ kg_nodes = pd.read_csv(project_config.DATA_DIR / 'kg_nodes.csv', dtype = {'node_index': int}, low_memory = False)
10
+ return kg_nodes
11
 
12
  def capitalize_after_slash(s):
13
  # Split the string by slashes first
 
60
  st.markdown(
61
  logo_markup,
62
  unsafe_allow_html=True,
63
+ )
64
+
65
+
66
+ # @st.cache_resource()
67
+ # def generate_profile_pic():
68
+
69
+ # st.markdown("""
70
+ # <style>
71
+ # .circle-image {
72
+ # width: 100px;
73
+ # height: 100px;
74
+ # border-radius: 50%;
75
+ # overflow: hidden;
76
+ # display: flex;
77
+ # justify-content: center;
78
+ # align-items: center;
79
+ # border: 2px solid black;
80
+ # margin: 0 auto 10px auto;
81
+ # }
82
+
83
+ # .circle-image img {
84
+ # width: 100%;
85
+ # max-width: 100px;
86
+ # height: 100%;
87
+ # object-fit: cover;
88
+ # }
89
+
90
+ # .username {
91
+ # font-size: 20px;
92
+ # font-weight: bold;
93
+ # text-align: center;
94
+ # margin-top: 0px;
95
+ # }
96
+ # </style>
97
+ # """, unsafe_allow_html=True)
98
+
99
+ # # Show the user's profile picture
100
+ # st.sidebar.html(f'<div class="circle-image"><img src="{st.session_state.profile_pic}" /></div>')
101
+
102
+ # # Show the user's name
103
+ # st.sidebar.html(f'<div class="username">{st.session_state.name}</div>')
104
+
105
+ # return None
106
+
107
+ # # Load image using PIL
108
+ # from PIL import Image, ImageDraw, ImageOps
109
+ # from io import BytesIO
110
+ # import requests
111
+
112
+ # def PIL_profile_pic():
113
+
114
+ # # Load user profile picture
115
+ # profile_pic = st.session_state.profile_pic
116
+ # response = requests.get(profile_pic)
117
+ # img = Image.open(BytesIO(response.content))
118
+
119
+ # # Create a circular mask
120
+ # min_dimension = min(img.size)
121
+ # mask = Image.new('L', (min_dimension, min_dimension), 0)
122
+ # draw = ImageDraw.Draw(mask)
123
+ # draw.ellipse((0, 0, min_dimension, min_dimension), fill=255)
124
+
125
+ # # Crop the image to a square of the smallest dimension
126
+ # left = (img.width - min_dimension) // 2
127
+ # top = (img.height - min_dimension) // 2
128
+ # right = (img.width + min_dimension) // 2
129
+ # bottom = (img.height + min_dimension) // 2
130
+ # img_cropped = img.crop((left, top, right, bottom))
131
+
132
+ # # Apply the circular mask to the cropped image
133
+ # img_circular = ImageOps.fit(img_cropped, (min_dimension, min_dimension))
134
+ # img_circular.putalpha(mask)
135
+
136
+ # st.markdown(
137
+ # """
138
+ # <style>
139
+ # [data-testid=stSidebar] [data-testid=stImage]{
140
+ # text-align: center;
141
+ # display: block;
142
+ # margin-left: auto;
143
+ # margin-right: auto;
144
+ # width: 100%;
145
+ # }
146
+ # </style>
147
+ # """, unsafe_allow_html=True
148
+ # )
149
+
150
+ # # Display the image
151
+ # st.sidebar.image(img_circular, width=200)
152
+ # st.sidebar.subheader(f"{st.session_state.name}")
153
+
154
+
155
+
156
+ # Generate the user's profile picture
157
+ # generate_profile_pic()