Spaces:
Sleeping
Sleeping
ACMCMC
commited on
Commit
•
e7d7b51
1
Parent(s):
47c6369
UI
Browse files- app.py +16 -13
- llm_res.py +12 -9
- requirements.txt +1 -0
- utils.py +64 -27
app.py
CHANGED
@@ -13,12 +13,14 @@ from utils import (
|
|
13 |
augment_the_set_of_diseaces,
|
14 |
get_clinical_trials_related_to_diseases,
|
15 |
get_clinical_records_by_ids,
|
16 |
-
render_trial_details
|
|
|
17 |
)
|
18 |
from llm_res import get_short_summary_out_of_json_files, tagging_insights_from_json
|
19 |
import json
|
20 |
import numpy as np
|
21 |
from sentence_transformers import SentenceTransformer
|
|
|
22 |
|
23 |
|
24 |
# variables to reveal next steps
|
@@ -71,8 +73,13 @@ with st.container():
|
|
71 |
status.write("Getting the similarities among the diseases to filter out less promising ones...")
|
72 |
diseases_uris = [disease["uri"] for disease in diseases_related_to_the_user_text]
|
73 |
similarities = get_similarities_among_diseases_uris(diseases_uris)
|
74 |
-
status.info(f'Obtained similarity information among the diseases by measuring the cosine similarity of their embeddings.
|
75 |
status.json(similarities, expanded=False)
|
|
|
|
|
|
|
|
|
|
|
76 |
status.divider()
|
77 |
# 4. Potentially filter out the diseases that are not similar enough (e.g. similarity < 0.8)
|
78 |
# 5. Augment the set of diseases: add new diseases that are similar to the ones that are already in the set, until we get 10-15 diseases
|
@@ -80,6 +87,8 @@ with st.container():
|
|
80 |
augmented_set_of_diseases = augment_the_set_of_diseaces(diseases_uris)
|
81 |
# print(augmented_set_of_diseases)
|
82 |
status.info(f'Augmented set of diseases: {len(augmented_set_of_diseases)} diseases.')
|
|
|
|
|
83 |
# 6. Query the embeddings of the diseases related to each clinical trial (also in the DB), to get the most similar clinical trials to our set of diseases
|
84 |
status.write("Getting the clinical trials related to the diseases found...")
|
85 |
clinical_trials_related_to_the_diseases = get_clinical_trials_related_to_diseases(
|
@@ -97,18 +106,19 @@ with st.container():
|
|
97 |
# 7. Use an LLM to get a summary of the clinical trials, in plain text format.
|
98 |
status.write("Getting a summary of the clinical trials...")
|
99 |
response, stats_dict = get_short_summary_out_of_json_files(json_of_clinical_trials)
|
100 |
-
|
101 |
-
print(f'basic_stats_dict:{stats_dict}')
|
102 |
status.write(f'Response from LLM summarization: {response}')
|
103 |
# 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of patients, number of deaths, etc.). Get summary statistics out of that.
|
104 |
status.write("Getting summary statistics of the clinical trials...")
|
105 |
-
response = tagging_insights_from_json(json_of_clinical_trials)
|
|
|
106 |
print(f'Response from LLM tagging: {response}')
|
107 |
status.write(f'Response from LLM tagging: {response}')
|
108 |
# 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
|
109 |
status.update(label="Done!", state="complete")
|
110 |
status.balloons()
|
111 |
show_graph = True
|
|
|
112 |
|
113 |
|
114 |
# graph
|
@@ -158,8 +168,7 @@ $$"""
|
|
158 |
# overview
|
159 |
with st.container():
|
160 |
if show_overview:
|
161 |
-
st.write("##
|
162 |
-
disease_overview = ":red[lorem ipsum]" # TODO
|
163 |
st.write(disease_overview)
|
164 |
time.sleep(2)
|
165 |
show_details = True
|
@@ -169,12 +178,6 @@ with st.container():
|
|
169 |
with st.container():
|
170 |
if show_details:
|
171 |
st.write("## Clinical Trials Details")
|
172 |
-
trials = []
|
173 |
-
# TODO replace mock data
|
174 |
-
with open("mock_trial.json") as f:
|
175 |
-
d = json.load(f)
|
176 |
-
for i in range(0, 8):
|
177 |
-
trials.append(d)
|
178 |
|
179 |
tab_titles = [f"{trial['protocolSection']['identificationModule']['nctId']}" for trial in trials]
|
180 |
|
|
|
13 |
augment_the_set_of_diseaces,
|
14 |
get_clinical_trials_related_to_diseases,
|
15 |
get_clinical_records_by_ids,
|
16 |
+
render_trial_details,
|
17 |
+
filter_out_less_promising_diseases
|
18 |
)
|
19 |
from llm_res import get_short_summary_out_of_json_files, tagging_insights_from_json
|
20 |
import json
|
21 |
import numpy as np
|
22 |
from sentence_transformers import SentenceTransformer
|
23 |
+
import matplotlib
|
24 |
|
25 |
|
26 |
# variables to reveal next steps
|
|
|
73 |
status.write("Getting the similarities among the diseases to filter out less promising ones...")
|
74 |
diseases_uris = [disease["uri"] for disease in diseases_related_to_the_user_text]
|
75 |
similarities = get_similarities_among_diseases_uris(diseases_uris)
|
76 |
+
status.info(f'Obtained similarity information among the diseases by measuring the cosine similarity of their embeddings.')
|
77 |
status.json(similarities, expanded=False)
|
78 |
+
filtered_diseases_uris, df_similarities = filter_out_less_promising_diseases(similarities)
|
79 |
+
# Apply a colormap to the table
|
80 |
+
status.table(df_similarities.style.background_gradient(cmap='viridis', axis=None))
|
81 |
+
status.info(f'Filtered out less promising diseases, keeping {len(filtered_diseases_uris)} diseases.')
|
82 |
+
status.json(filtered_diseases_uris, expanded=False)
|
83 |
status.divider()
|
84 |
# 4. Potentially filter out the diseases that are not similar enough (e.g. similarity < 0.8)
|
85 |
# 5. Augment the set of diseases: add new diseases that are similar to the ones that are already in the set, until we get 10-15 diseases
|
|
|
87 |
augmented_set_of_diseases = augment_the_set_of_diseaces(diseases_uris)
|
88 |
# print(augmented_set_of_diseases)
|
89 |
status.info(f'Augmented set of diseases: {len(augmented_set_of_diseases)} diseases.')
|
90 |
+
status.json(augmented_set_of_diseases, expanded=False)
|
91 |
+
status.divider()
|
92 |
# 6. Query the embeddings of the diseases related to each clinical trial (also in the DB), to get the most similar clinical trials to our set of diseases
|
93 |
status.write("Getting the clinical trials related to the diseases found...")
|
94 |
clinical_trials_related_to_the_diseases = get_clinical_trials_related_to_diseases(
|
|
|
106 |
# 7. Use an LLM to get a summary of the clinical trials, in plain text format.
|
107 |
status.write("Getting a summary of the clinical trials...")
|
108 |
response, stats_dict = get_short_summary_out_of_json_files(json_of_clinical_trials)
|
109 |
+
disease_overview = response
|
|
|
110 |
status.write(f'Response from LLM summarization: {response}')
|
111 |
# 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of patients, number of deaths, etc.). Get summary statistics out of that.
|
112 |
status.write("Getting summary statistics of the clinical trials...")
|
113 |
+
#response = tagging_insights_from_json(json_of_clinical_trials)
|
114 |
+
response = ""
|
115 |
print(f'Response from LLM tagging: {response}')
|
116 |
status.write(f'Response from LLM tagging: {response}')
|
117 |
# 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
|
118 |
status.update(label="Done!", state="complete")
|
119 |
status.balloons()
|
120 |
show_graph = True
|
121 |
+
trials = json_of_clinical_trials
|
122 |
|
123 |
|
124 |
# graph
|
|
|
168 |
# overview
|
169 |
with st.container():
|
170 |
if show_overview:
|
171 |
+
st.write("## Overview of Related Clinical Trials")
|
|
|
172 |
st.write(disease_overview)
|
173 |
time.sleep(2)
|
174 |
show_details = True
|
|
|
178 |
with st.container():
|
179 |
if show_details:
|
180 |
st.write("## Clinical Trials Details")
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
|
182 |
tab_titles = [f"{trial['protocolSection']['identificationModule']['nctId']}" for trial in trials]
|
183 |
|
llm_res.py
CHANGED
@@ -221,20 +221,23 @@ def process_dictionaty_with_llm_to_generate_response(json_data):
|
|
221 |
return filtered_data
|
222 |
|
223 |
def get_short_summary_out_of_json_files(data_json):
|
224 |
-
prompt_template = """You are an expert
|
|
|
|
|
|
|
225 |
|
226 |
-
|
227 |
-
|
|
|
228 |
|
229 |
-
#
|
|
|
230 |
|
231 |
-
|
232 |
|
233 |
-
|
234 |
|
235 |
-
|
236 |
-
# Task
|
237 |
-
You will be given a text of descriptions of multiple clinical trials realed to similar diseases. Your job is to come up with a short and detailed summary of the descriptions of the clinical trials. Your users are clinical researchers, so you should be technical and specific, including scientific terms in the summary."""
|
238 |
|
239 |
prompt = PromptTemplate.from_template(prompt_template)
|
240 |
|
|
|
221 |
return filtered_data
|
222 |
|
223 |
def get_short_summary_out_of_json_files(data_json):
|
224 |
+
prompt_template = """You are an expert on clinicial trials and their analysis of their reports.
|
225 |
+
|
226 |
+
# Task
|
227 |
+
You will be given a text of descriptions of multiple clinical trials realed to similar diseases. Your job is to come up with a short and detailed summary of the descriptions of the clinical trials. Your users are clinical researchers, so you should be technical and specific, including scientific terms in the summary.
|
228 |
|
229 |
+
{text}"""
|
230 |
+
|
231 |
+
prompt_template = """You are an expert clinician working on the analysis of reports of clinical trials.
|
232 |
|
233 |
+
# Task
|
234 |
+
You will be given a set of descriptions of clinical trials. Your job is to come up with a short summary (100-200 words) of the descriptions of the clinical trials. Your users are clinical researchers who are experts in medicine, so you should be technical and specific, including scientific terms. Always be faithful to the original information written in the reports.
|
235 |
|
236 |
+
To write your summary, you will need to read the following examples, labeled as "Report 1", "Report 2", and so on. Your answer should be a single paragraph (100-200 words) that summarizes the general content of all the reports. Format your answer in Markdown format, **highlighting** the most important concepts, and _italicizing_ the technical concepts extracted from the reports. Be very specific about the details of the clinical trials.
|
237 |
|
238 |
+
{text}
|
239 |
|
240 |
+
General summary:"""
|
|
|
|
|
241 |
|
242 |
prompt = PromptTemplate.from_template(prompt_template)
|
243 |
|
requirements.txt
CHANGED
@@ -11,3 +11,4 @@ sentence_transformers==2.7.0
|
|
11 |
streamlit-agraph
|
12 |
streamlit==1.34.0
|
13 |
langchain-openai==0.1.6
|
|
|
|
11 |
streamlit-agraph
|
12 |
streamlit==1.34.0
|
13 |
langchain-openai==0.1.6
|
14 |
+
matplotlib==3.8.4
|
utils.py
CHANGED
@@ -5,6 +5,7 @@ from sqlalchemy import create_engine, text
|
|
5 |
import requests
|
6 |
from sentence_transformers import SentenceTransformer
|
7 |
import streamlit as st
|
|
|
8 |
|
9 |
username = "demo"
|
10 |
password = "demo"
|
@@ -121,7 +122,12 @@ def get_similarities_among_diseases_uris(
|
|
121 |
"""
|
122 |
result = conn.execute(text(sql))
|
123 |
data = result.fetchall()
|
124 |
-
|
|
|
|
|
|
|
|
|
|
|
125 |
|
126 |
|
127 |
def augment_the_set_of_diseaces(diseases: List[str]) -> str:
|
@@ -169,7 +175,7 @@ def get_diseases_related_to_a_textual_description(
|
|
169 |
result = conn.execute(text(sql))
|
170 |
data = result.fetchall()
|
171 |
|
172 |
-
return [{"uri": row[0], "distance": row[1]} for row in data if row[1] > 0.8]
|
173 |
|
174 |
def get_clinical_trials_related_to_diseases(
|
175 |
diseases: List[str], encoder
|
@@ -191,6 +197,20 @@ def get_clinical_trials_related_to_diseases(
|
|
191 |
|
192 |
return [{"nct_id": row[0], "distance": row[1]} for row in data]
|
193 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
def to_capitalized_case(string: str) -> str:
|
195 |
string = string.replace("_", " ")
|
196 |
if string.isupper():
|
@@ -206,36 +226,53 @@ def render_trial_details(trial: dict) -> None:
|
|
206 |
official_title = trial["protocolSection"]["identificationModule"]["officialTitle"]
|
207 |
st.write(f"##### {official_title}")
|
208 |
|
209 |
-
|
210 |
-
|
|
|
|
|
|
|
|
|
|
|
211 |
|
212 |
-
status_module = {
|
213 |
-
"Status": to_capitalized_case(trial["protocolSection"]["statusModule"]["overallStatus"]),
|
214 |
-
"Status Date": trial["protocolSection"]["statusModule"]["statusVerifiedDate"],
|
215 |
-
"Has Results": trial["hasResults"]
|
216 |
-
}
|
217 |
st.write("###### Status")
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
"
|
227 |
-
|
228 |
-
}
|
229 |
st.write("###### Design")
|
230 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
231 |
|
232 |
-
interventions_module = {}
|
233 |
-
for intervention in trial["protocolSection"]["armsInterventionsModule"]["interventions"]:
|
234 |
-
name = intervention["name"]
|
235 |
-
desc = intervention["description"]
|
236 |
-
interventions_module[name] = desc
|
237 |
st.write("###### Interventions")
|
238 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
239 |
|
240 |
if __name__ == "__main__":
|
241 |
username = "demo"
|
|
|
5 |
import requests
|
6 |
from sentence_transformers import SentenceTransformer
|
7 |
import streamlit as st
|
8 |
+
import pandas as pd
|
9 |
|
10 |
username = "demo"
|
11 |
password = "demo"
|
|
|
122 |
"""
|
123 |
result = conn.execute(text(sql))
|
124 |
data = result.fetchall()
|
125 |
+
|
126 |
+
return [{
|
127 |
+
"uri1": row[0].split("/")[-1],
|
128 |
+
"uri2": row[1].split("/")[-1],
|
129 |
+
"distance": float(row[2]),
|
130 |
+
} for row in data]
|
131 |
|
132 |
|
133 |
def augment_the_set_of_diseaces(diseases: List[str]) -> str:
|
|
|
175 |
result = conn.execute(text(sql))
|
176 |
data = result.fetchall()
|
177 |
|
178 |
+
return [{"uri": row[0], "distance": float(row[1])} for row in data if float(row[1]) > 0.8]
|
179 |
|
180 |
def get_clinical_trials_related_to_diseases(
|
181 |
diseases: List[str], encoder
|
|
|
197 |
|
198 |
return [{"nct_id": row[0], "distance": row[1]} for row in data]
|
199 |
|
200 |
+
def filter_out_less_promising_diseases(info_dicts: List[Dict[str, Any]]) -> List[str]:
|
201 |
+
# Find out the score of each disease by averaging the cosine similarity of the embeddings of the diseases that include it as uri1 or uri2
|
202 |
+
df_diseases_similarities = pd.DataFrame(info_dicts)
|
203 |
+
# Use uri1 as the index, and uri2 as the columns. The values are the distances.
|
204 |
+
df_diseases_similarities = df_diseases_similarities.pivot(index="uri1", columns="uri2", values="distance")
|
205 |
+
# Fill the diagonal with 1.0
|
206 |
+
df_diseases_similarities = df_diseases_similarities.fillna(1.0)
|
207 |
+
|
208 |
+
# Filter out the diseases that are 1 standard deviation below the mean
|
209 |
+
mean = df_diseases_similarities.mean().mean()
|
210 |
+
std = df_diseases_similarities.mean().std()
|
211 |
+
filtered_diseases = df_diseases_similarities.mean()[df_diseases_similarities.mean() > mean - std].index.tolist()
|
212 |
+
return filtered_diseases, df_diseases_similarities
|
213 |
+
|
214 |
def to_capitalized_case(string: str) -> str:
|
215 |
string = string.replace("_", " ")
|
216 |
if string.isupper():
|
|
|
226 |
official_title = trial["protocolSection"]["identificationModule"]["officialTitle"]
|
227 |
st.write(f"##### {official_title}")
|
228 |
|
229 |
+
try:
|
230 |
+
st.write(trial["protocolSection"]["descriptionModule"]["briefSummary"])
|
231 |
+
except KeyError:
|
232 |
+
try:
|
233 |
+
st.write(trial["protocolSection"]["descriptionModule"]["detailedDescription"])
|
234 |
+
except KeyError:
|
235 |
+
st.error("No description available.")
|
236 |
|
|
|
|
|
|
|
|
|
|
|
237 |
st.write("###### Status")
|
238 |
+
try:
|
239 |
+
status_module = {
|
240 |
+
"Status": to_capitalized_case(trial["protocolSection"]["statusModule"]["overallStatus"]),
|
241 |
+
"Status Date": trial["protocolSection"]["statusModule"]["statusVerifiedDate"],
|
242 |
+
"Has Results": trial["hasResults"]
|
243 |
+
}
|
244 |
+
st.table(status_module)
|
245 |
+
except KeyError:
|
246 |
+
st.info("No status information available.")
|
247 |
+
|
|
|
248 |
st.write("###### Design")
|
249 |
+
try:
|
250 |
+
design_module = {
|
251 |
+
"Study Type": to_capitalized_case(trial["protocolSection"]["designModule"]["studyType"]),
|
252 |
+
"Phases": list_to_capitalized_case(trial["protocolSection"]["designModule"]["phases"]),
|
253 |
+
"Allocation": to_capitalized_case(trial["protocolSection"]["designModule"]["designInfo"]["allocation"]),
|
254 |
+
"Primary Purpose": to_capitalized_case(trial["protocolSection"]["designModule"]["designInfo"]["primaryPurpose"]),
|
255 |
+
"Participants": trial["protocolSection"]["designModule"]["enrollmentInfo"]["count"],
|
256 |
+
"Masking": to_capitalized_case(trial["protocolSection"]["designModule"]["designInfo"]["maskingInfo"]["masking"]),
|
257 |
+
"Who Masked": list_to_capitalized_case(trial["protocolSection"]["designModule"]["designInfo"]["maskingInfo"]["whoMasked"])
|
258 |
+
}
|
259 |
+
st.table(design_module)
|
260 |
+
except KeyError:
|
261 |
+
st.info("No design information available.")
|
262 |
|
|
|
|
|
|
|
|
|
|
|
263 |
st.write("###### Interventions")
|
264 |
+
try:
|
265 |
+
interventions_module = {}
|
266 |
+
for intervention in trial["protocolSection"]["armsInterventionsModule"]["interventions"]:
|
267 |
+
name = intervention["name"]
|
268 |
+
desc = intervention["description"]
|
269 |
+
interventions_module[name] = desc
|
270 |
+
st.table(interventions_module)
|
271 |
+
except KeyError:
|
272 |
+
st.info("No interventions information available.")
|
273 |
+
|
274 |
+
# Button to go to ClinicalTrials.gov and see the trial. It takes the user to the official page of the trial.
|
275 |
+
st.markdown(f"See more in [ClinicalTrials.gov](https://clinicaltrials.gov/study/{trial['protocolSection']['identificationModule']['nctId']})")
|
276 |
|
277 |
if __name__ == "__main__":
|
278 |
username = "demo"
|