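# Page-status residue from the hosting page this file was copied from
# ("Spaces:", "Sleeping", "Sleeping"), kept as a comment so the module parses.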
import json
import os
import time
from urllib.parse import quote_plus

import matplotlib
import numpy as np
import pandas as pd
import streamlit as st
from sentence_transformers import SentenceTransformer
from sqlalchemy import create_engine, text
from streamlit_agraph import Config, Edge, Node, agraph

from llm_res import get_short_summary_out_of_json_files, tagging_insights_from_json
from utils import (
    augment_the_set_of_diseaces,
    filter_out_less_promising_diseases,
    get_all_diseases_name,
    get_clinical_records_by_ids,
    get_clinical_trials_related_to_diseases,
    get_diseases_related_to_a_textual_description,
    get_most_similar_diseases_from_uri,
    get_similarities_among_diseases_uris,
    get_similarities_df,
    get_uri_from_name,
    render_trial_details,
    get_labels_of_diseases_from_uris,
)
# Flags that reveal the result sections step by step. `show_analyze_status`
# is overwritten by the "Analyze" button; the others are switched on as each
# earlier section finishes rendering.
show_graph = show_analyze_status = show_overview = show_details = show_metrics = False
# IRIS connection
# Every setting may be overridden through the environment; when a variable is
# unset the built-in demo settings below are used.
username = os.getenv("IRIS_USERNAME", "demo")
password = os.getenv("IRIS_PASSWORD", "demo")
hostname = os.getenv("IRIS_HOSTNAME", "localhost")
port = os.getenv("IRIS_PORT", "1972")
namespace = os.getenv("IRIS_NAMESPACE", "USER")
# Credentials are percent-encoded so characters such as "@", ":" or "/" in a
# password cannot break the URL; SQLAlchemy decodes them when parsing it.
CONNECTION_STRING = (
    f"iris://{quote_plus(username)}:{quote_plus(password)}"
    f"@{hostname}:{port}/{namespace}"
)
engine = create_engine(CONNECTION_STRING)
# Page header: banner image, app title and one-line tagline.
st.image(
    "img_klinic.jpeg",
    caption="(AI-generated image)",
    use_column_width=True,
)
st.title("Klìnic", help="AI-powered clinical trial search engine")
st.subheader(
    "Find clinical trials in a scoped domain of biomedical research, "
    "guiding your research with AI-powered insights."
)
# User input: the description text area sits next to the "Analyze" button.
with st.container():
    text_column, button_column = st.columns((6, 1))
    with text_column:
        description_input = st.text_area(
            label="Enter a disease description 👇",
            placeholder="A disorder manifested in memory loss and other cognitive impairments among elderly patients (60+ years old), especially women.",
        )
    with button_column:
        # Empty text rows push the button down so it lines up with the text area.
        for _ in range(3):
            st.text("")
        show_analyze_status = st.button("Analyze 🔎")
# analyze


@st.cache_resource
def _load_encoder():
    """Return the SPECTER sentence encoder used to embed text.

    Decorated with ``st.cache_resource`` so the model is constructed once and
    reused on later reruns instead of being rebuilt on every "Analyze" click.
    """
    return SentenceTransformer("allenai-specter")


with st.container():
    if show_analyze_status and not (description_input or "").strip():
        # An empty description would still be embedded and matched, which is
        # not meaningful; ask for input instead of running the pipeline.
        st.warning("Please enter a disease description before clicking Analyze.")
    elif show_analyze_status:
        with st.status("Analyzing...") as status:
            # 1. Embed the textual description that the user entered using the model.
            # 2. Get the diseases whose embeddings are most cosine-similar to it from the DB.
            status.write("Analyzing the description that you wrote...")
            encoder = _load_encoder()
            diseases_related_to_the_user_text = (
                get_diseases_related_to_a_textual_description(
                    description_input, encoder
                )
            )
            status.info(
                f"Selected {len(diseases_related_to_the_user_text)} diseases related to the description you entered."
            )
            status.json(diseases_related_to_the_user_text, expanded=False)
            status.divider()

            # 3. Get the pairwise similarities of those diseases (cosine similarity
            #    of the embeddings of their nodes) and filter out less promising ones.
            status.write(
                "Getting the similarities among the diseases to filter out less promising ones..."
            )
            diseases_uris = [
                disease["uri"] for disease in diseases_related_to_the_user_text
            ]
            similarities = get_similarities_among_diseases_uris(diseases_uris)
            status.info(
                "Obtained similarity information among the diseases by measuring the cosine similarity of their embeddings."
            )
            status.json(similarities, expanded=False)
            filtered_diseases_uris, df_similarities = (
                filter_out_less_promising_diseases(similarities)
            )
            # Apply a colormap to the table
            status.table(
                df_similarities.style.background_gradient(cmap="viridis", axis=None)
            )
            status.info(
                f"Filtered out less promising diseases, keeping {len(filtered_diseases_uris)} diseases."
            )
            status.json(filtered_diseases_uris, expanded=False)
            status.divider()

            # 4. Augment the set of diseases: add diseases whose embeddings are
            #    related to the ones already kept.
            status.write(
                "Augmenting the set of diseases by finding others with related embeddings..."
            )
            augmented_set_of_diseases = augment_the_set_of_diseaces(
                filtered_diseases_uris
            )
            similarities_of_augmented_set_of_diseases = (
                get_similarities_among_diseases_uris(augmented_set_of_diseases)
            )
            df_similarities_augmented_set = get_similarities_df(
                similarities_of_augmented_set_of_diseases
            )
            status.info(
                f"Augmented set of diseases: {len(augmented_set_of_diseases)} diseases."
            )
            status.table(
                df_similarities_augmented_set.style.background_gradient(
                    cmap="viridis", axis=None
                )
            )
            status.json(augmented_set_of_diseases, expanded=False)
            status.divider()

            # 5. Get the clinical trials related to the augmented set of diseases.
            status.write("Getting the clinical trials related to the diseases found...")
            clinical_trials_related_to_the_diseases = (
                get_clinical_trials_related_to_diseases(
                    augmented_set_of_diseases, encoder
                )
            )
            status.info(
                f"Selected {len(clinical_trials_related_to_the_diseases)} clinical trials related to the diseases."
            )
            status.json(clinical_trials_related_to_the_diseases, expanded=False)
            status.divider()

            # 6. Fetch the full records of those trials by their NCT ids.
            status.write("Getting the details of the clinical trials...")
            json_of_clinical_trials = get_clinical_records_by_ids(
                [trial["nct_id"] for trial in clinical_trials_related_to_the_diseases]
            )
            status.success("Details of the clinical trials obtained.")
            status.json(json_of_clinical_trials, expanded=False)
            status.divider()

            # 7. Use an LLM to get a summary of the clinical trials, in plain text.
            # Best-effort: on failure `disease_overview` is left undefined, so
            # later sections must tolerate its absence.
            try:
                status.write("Getting a summary of the clinical trials...")
                response = get_short_summary_out_of_json_files(json_of_clinical_trials)
                status.success("Summary of the clinical trials obtained.")
                disease_overview = response
            except Exception as e:
                print(f"Error while getting a summary of the clinical trials: {e}")
                status.warning(
                    "Error while getting a summary of the clinical trials. This information will not be shown."
                )

            # 8. Use an LLM to extract summary statistics from the clinical trials
            #    (the keys read are avg_min_age, avg_max_age, most_common_gender).
            #    Best-effort as well: on failure the metric variables stay undefined.
            try:
                status.write("Getting summary statistics of the clinical trials...")
                response = tagging_insights_from_json(json_of_clinical_trials)
                average_minimum_age = response["avg_min_age"]
                average_maximum_age = response["avg_max_age"]
                most_common_gender = response["most_common_gender"]
                print(f"Response from LLM tagging: {response}")
                status.success("Summary statistics of the clinical trials obtained.")
            except Exception as e:
                print(
                    f"Error while extracting numerical data from the clinical trials: {e}"
                )
                status.warning(
                    "Error while extracting numerical data from the clinical trials. This information will not be shown."
                )

            # 9. Reveal the results: graph of the diseases, summary, statistics
            #    and the details of every clinical trial considered.
            status.update(label="Done!", state="complete")
            status.balloons()
            show_graph = True
            trials = json_of_clinical_trials
# graph
with st.container():
    if show_graph:
        st.info(
            """This is a graph of the relevant diseases that we found, based on the description that you entered. The diseases are connected by edges if they are similar to each other. The color of the edges represents the similarity of the diseases.

We use the embeddings of the diseases to determine the similarity between them. The embeddings are generated using a Representation Learning algorithm that learns the topological relations among the nodes in the graph, depending on how they are connected. We utilize the [PyKEEN](https://github.com/pykeen/pykeen) implementation of TransH to train an embedding model.

[TransH](https://ojs.aaai.org/index.php/AAAI/article/view/8870) utilizes hyperplanes to model relations between entities. It is a multi-relational model that can handle many-to-many relations between entities. The model is trained on the triples of the graph, where the triples are the subject, relation, and object of the graph. The model learns the embeddings of the entities and the relations such that, when a triple is true, the subject projected onto the hyperplane of the relation and translated by the relation vector lands close to the projected object.

Specifically, it optimizes the following margin-based ranking loss:

$\\text{minimize} \\sum_{(h, r, t) \\in S} \\sum_{(h', r, t') \\in S'_{(h, r, t)}} \\max(0, \\gamma + f_r(h, t) - f_r(h', t'))$

where $f_r(h, t) = \\lVert h_\\perp + d_r - t_\\perp \\rVert_2^2$ with $h_\\perp = h - w_r^\\top h \\, w_r$ (and likewise for $t_\\perp$), $S'_{(h, r, t)}$ holds corrupted (negative) triples obtained by replacing the head or the tail, and $\\gamma$ is the margin.

By minimizing this cost function, the model learns the embeddings of the entities and relations that best represent the graph. The embeddings are then used to calculate the similarity between the diseases, which is shown in the graph.
"""
        )
        try:
            print(f'df_similarities_augmented_set.index: {df_similarities_augmented_set.index}')
            labels_of_diseases = get_labels_of_diseases_from_uris(
                augmented_set_of_diseases
            )
            print(f'labels_of_diseases: {labels_of_diseases}')
            # Pair each label with the URI it was requested for. The dataframe
            # index is not guaranteed to share that order, so it is not used here.
            uris_and_labels_of_diseases = dict(
                zip(augmented_set_of_diseases, labels_of_diseases)
            )
            print(f'uris_and_labels_of_diseases: {uris_and_labels_of_diseases}')

            # matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9;
            # the colormap registry is the supported way to fetch a colormap.
            color_mapper = matplotlib.colormaps["viridis"]
            edges_to_show = []
            for source in df_similarities_augmented_set.index:
                for target in df_similarities_augmented_set.columns:
                    if source == target:
                        continue
                    weight = df_similarities_augmented_set.loc[source, target]
                    edges_to_show.append(
                        Edge(
                            source=source,
                            target=target,
                            # Edge colour follows the similarity; to_hex drops the alpha channel.
                            color=matplotlib.colors.to_hex(color_mapper(weight)),
                            # A high power exaggerates differences between similar weights.
                            weight=weight**10,
                            type="CURVE_SMOOTH",
                            label=f"{weight:.2f}",
                        )
                    )

            nodes_to_show = []
            for disease in df_similarities_augmented_set.index:
                label = uris_and_labels_of_diseases.get(disease)
                # Missing, NaN or literal "nan" labels fall back to the URI.
                if label is None or pd.isna(label) or label == "nan":
                    label = disease
                nodes_to_show.append(
                    Node(id=disease, label=label, size=50, shape="circular")
                )

            graph_of_diseases = agraph(
                nodes=nodes_to_show,
                edges=edges_to_show,
                config=Config(height=500, width=500),
            )
            time.sleep(2)
        except Exception as e:
            print(f"Error while showing the graph of the diseases: {e}")
            st.error("Error while showing the graph of the diseases.")
        finally:
            # Reveal the overview even if the graph could not be drawn.
            show_overview = True
# overview
with st.container():
    if show_overview:
        try:
            st.write("## Overview of Related Clinical Trials")
            # `disease_overview` only exists if the LLM summary succeeded;
            # otherwise this raises NameError, which is logged below.
            st.write(disease_overview)
            time.sleep(1)
        except Exception as err:
            print("Error while showing the overview of the clinical trials: " + str(err))
        finally:
            # Always move on to the metrics section.
            show_metrics = True
# metrics
with st.container():
    if show_metrics:
        try:
            st.write("## Metrics of the Clinical Trials")
            min_age_column, max_age_column, gender_column = st.columns(3)
            # Each value comes from a variable set when the statistics were
            # extracted; the first undefined one raises NameError and is logged.
            min_age_column.metric("Average Minimum Age", average_minimum_age)
            max_age_column.metric("Average Maximum Age", average_maximum_age)
            gender_column.metric("Most Common Gender", most_common_gender)
            time.sleep(2)
        except Exception as err:
            print(f"Error while showing the metrics: {err}")
        finally:
            # Always move on to the details section.
            show_details = True
# details
with st.container():
    if show_details:
        st.write("## Clinical Trials Details")
        if trials:
            # One tab per trial, titled with its NCT id.
            tab_titles = [
                f"{trial['protocolSection']['identificationModule']['nctId']}"
                for trial in trials
            ]
            tabs = st.tabs(tab_titles)
            # st.tabs returns the containers in title order, so pair them directly.
            for tab, trial in zip(tabs, trials):
                with tab:
                    render_trial_details(trial)
        else:
            # st.tabs rejects an empty list of labels; explain the empty result instead.
            st.info("No clinical trials were found for the selected diseases.")
# Footer: credits for the team that built the app. Display names match the
# profile URLs they link to.
st.divider()
st.markdown(
    """This app was created at HackUPC 2024 by the team 'Klìnic'. The team members are:
- [Aldan Creo](https://acmc-website.web.app)
- [Matthias Seiler](https://www.linkedin.com/in/maseiler/)
- [Tanguy Vansnick](https://www.linkedin.com/in/tanguy-vansnick-44186a199/)
- [Arijit Samal](https://www.linkedin.com/in/arijit-samal1/)
"""
)
# Standalone explorer: pick one disease and graph its most similar diseases.
# Disabled: the flag is hard-coded to False, so nothing below runs by default.
show_graph_of_all_diseases = False
if show_graph_of_all_diseases:
    # Fetch the disease names once per session and keep them across reruns.
    if "disease_names" not in st.session_state:
        st.session_state.disease_names = get_all_diseases_name(engine)
    chosen_disease_name = st.selectbox(
        "Choose a disease",
        st.session_state.disease_names,
    )
    st.write("You selected:", chosen_disease_name)
    chosen_disease_uri = get_uri_from_name(engine, chosen_disease_name)
    nodes = [
        Node(
            id=chosen_disease_uri, label=chosen_disease_name, size=25, shape="circular"
        )
    ]
    edges = []
    similar_diseases = get_most_similar_diseases_from_uri(
        engine, chosen_disease_uri, threshold=0.6
    )
    print(similar_diseases)
    for uri, name, weight in similar_diseases:
        # Convert once; the helper's weight type is not guaranteed to be numeric.
        similarity = float(weight)
        is_strong_match = similarity > 0.7
        nodes.append(Node(id=uri, label=name, size=25, shape="circular"))
        print(is_strong_match)
        edges.append(
            Edge(
                source=chosen_disease_uri,
                target=uri,
                # Strong matches (> 0.7) in red, the rest in blue.
                color="red" if is_strong_match else "blue",
                weight=similarity**10,
                type="CURVE_SMOOTH",
            )
        )
    config = Config(
        width=750,
        height=950,
        directed=False,
        physics=True,
        hierarchical=False,
        collapsible=False,
    )
    return_value = agraph(nodes=nodes, edges=edges, config=config)