ACMCMC committed on
Commit
e7d7b51
1 Parent(s): 47c6369
Files changed (4) hide show
  1. app.py +16 -13
  2. llm_res.py +12 -9
  3. requirements.txt +1 -0
  4. utils.py +64 -27
app.py CHANGED
@@ -13,12 +13,14 @@ from utils import (
13
  augment_the_set_of_diseaces,
14
  get_clinical_trials_related_to_diseases,
15
  get_clinical_records_by_ids,
16
- render_trial_details
 
17
  )
18
  from llm_res import get_short_summary_out_of_json_files, tagging_insights_from_json
19
  import json
20
  import numpy as np
21
  from sentence_transformers import SentenceTransformer
 
22
 
23
 
24
  # variables to reveal next steps
@@ -71,8 +73,13 @@ with st.container():
71
  status.write("Getting the similarities among the diseases to filter out less promising ones...")
72
  diseases_uris = [disease["uri"] for disease in diseases_related_to_the_user_text]
73
  similarities = get_similarities_among_diseases_uris(diseases_uris)
74
- status.info(f'Obtained similarity information among the diseases by measuring the cosine similarity of their embeddings. Using the similarity information to filter out less promising diseases.')
75
  status.json(similarities, expanded=False)
 
 
 
 
 
76
  status.divider()
77
  # 4. Potentially filter out the diseases that are not similar enough (e.g. similarity < 0.8)
78
  # 5. Augment the set of diseases: add new diseases that are similar to the ones that are already in the set, until we get 10-15 diseases
@@ -80,6 +87,8 @@ with st.container():
80
  augmented_set_of_diseases = augment_the_set_of_diseaces(diseases_uris)
81
  # print(augmented_set_of_diseases)
82
  status.info(f'Augmented set of diseases: {len(augmented_set_of_diseases)} diseases.')
 
 
83
  # 6. Query the embeddings of the diseases related to each clinical trial (also in the DB), to get the most similar clinical trials to our set of diseases
84
  status.write("Getting the clinical trials related to the diseases found...")
85
  clinical_trials_related_to_the_diseases = get_clinical_trials_related_to_diseases(
@@ -97,18 +106,19 @@ with st.container():
97
  # 7. Use an LLM to get a summary of the clinical trials, in plain text format.
98
  status.write("Getting a summary of the clinical trials...")
99
  response, stats_dict = get_short_summary_out_of_json_files(json_of_clinical_trials)
100
- print(f'Response from LLM summarization: {response}')
101
- print(f'basic_stats_dict:{stats_dict}')
102
  status.write(f'Response from LLM summarization: {response}')
103
  # 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of patients, number of deaths, etc.). Get summary statistics out of that.
104
  status.write("Getting summary statistics of the clinical trials...")
105
- response = tagging_insights_from_json(json_of_clinical_trials)
 
106
  print(f'Response from LLM tagging: {response}')
107
  status.write(f'Response from LLM tagging: {response}')
108
  # 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
109
  status.update(label="Done!", state="complete")
110
  status.balloons()
111
  show_graph = True
 
112
 
113
 
114
  # graph
@@ -158,8 +168,7 @@ $$"""
158
  # overview
159
  with st.container():
160
  if show_overview:
161
- st.write("## Disease Overview")
162
- disease_overview = ":red[lorem ipsum]" # TODO
163
  st.write(disease_overview)
164
  time.sleep(2)
165
  show_details = True
@@ -169,12 +178,6 @@ with st.container():
169
  with st.container():
170
  if show_details:
171
  st.write("## Clinical Trials Details")
172
- trials = []
173
- # TODO replace mock data
174
- with open("mock_trial.json") as f:
175
- d = json.load(f)
176
- for i in range(0, 8):
177
- trials.append(d)
178
 
179
  tab_titles = [f"{trial['protocolSection']['identificationModule']['nctId']}" for trial in trials]
180
 
 
13
  augment_the_set_of_diseaces,
14
  get_clinical_trials_related_to_diseases,
15
  get_clinical_records_by_ids,
16
+ render_trial_details,
17
+ filter_out_less_promising_diseases
18
  )
19
  from llm_res import get_short_summary_out_of_json_files, tagging_insights_from_json
20
  import json
21
  import numpy as np
22
  from sentence_transformers import SentenceTransformer
23
+ import matplotlib
24
 
25
 
26
  # variables to reveal next steps
 
73
  status.write("Getting the similarities among the diseases to filter out less promising ones...")
74
  diseases_uris = [disease["uri"] for disease in diseases_related_to_the_user_text]
75
  similarities = get_similarities_among_diseases_uris(diseases_uris)
76
+ status.info(f'Obtained similarity information among the diseases by measuring the cosine similarity of their embeddings.')
77
  status.json(similarities, expanded=False)
78
+ filtered_diseases_uris, df_similarities = filter_out_less_promising_diseases(similarities)
79
+ # Apply a colormap to the table
80
+ status.table(df_similarities.style.background_gradient(cmap='viridis', axis=None))
81
+ status.info(f'Filtered out less promising diseases, keeping {len(filtered_diseases_uris)} diseases.')
82
+ status.json(filtered_diseases_uris, expanded=False)
83
  status.divider()
84
  # 4. Potentially filter out the diseases that are not similar enough (e.g. similarity < 0.8)
85
  # 5. Augment the set of diseases: add new diseases that are similar to the ones that are already in the set, until we get 10-15 diseases
 
87
  augmented_set_of_diseases = augment_the_set_of_diseaces(diseases_uris)
88
  # print(augmented_set_of_diseases)
89
  status.info(f'Augmented set of diseases: {len(augmented_set_of_diseases)} diseases.')
90
+ status.json(augmented_set_of_diseases, expanded=False)
91
+ status.divider()
92
  # 6. Query the embeddings of the diseases related to each clinical trial (also in the DB), to get the most similar clinical trials to our set of diseases
93
  status.write("Getting the clinical trials related to the diseases found...")
94
  clinical_trials_related_to_the_diseases = get_clinical_trials_related_to_diseases(
 
106
  # 7. Use an LLM to get a summary of the clinical trials, in plain text format.
107
  status.write("Getting a summary of the clinical trials...")
108
  response, stats_dict = get_short_summary_out_of_json_files(json_of_clinical_trials)
109
+ disease_overview = response
 
110
  status.write(f'Response from LLM summarization: {response}')
111
  # 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of patients, number of deaths, etc.). Get summary statistics out of that.
112
  status.write("Getting summary statistics of the clinical trials...")
113
+ #response = tagging_insights_from_json(json_of_clinical_trials)
114
+ response = ""
115
  print(f'Response from LLM tagging: {response}')
116
  status.write(f'Response from LLM tagging: {response}')
117
  # 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
118
  status.update(label="Done!", state="complete")
119
  status.balloons()
120
  show_graph = True
121
+ trials = json_of_clinical_trials
122
 
123
 
124
  # graph
 
168
  # overview
169
  with st.container():
170
  if show_overview:
171
+ st.write("## Overview of Related Clinical Trials")
 
172
  st.write(disease_overview)
173
  time.sleep(2)
174
  show_details = True
 
178
  with st.container():
179
  if show_details:
180
  st.write("## Clinical Trials Details")
 
 
 
 
 
 
181
 
182
  tab_titles = [f"{trial['protocolSection']['identificationModule']['nctId']}" for trial in trials]
183
 
llm_res.py CHANGED
@@ -221,20 +221,23 @@ def process_dictionaty_with_llm_to_generate_response(json_data):
221
  return filtered_data
222
 
223
  def get_short_summary_out_of_json_files(data_json):
224
- prompt_template = """You are an expert clinician working on the analysis of reports of clinical trials.
 
 
 
225
 
226
- # # Task
227
- # You will be given a set of descriptions of clinical trials. Your job is to come up with a short summary (100-200 words) of the descriptions of the clinical trials. Your users are clinical researchers who are experts in medicine, so you should be technical and specific, including scientific terms. Always be faithful to the original information written in the reports.
 
228
 
229
- # To write your summary, you will need to read the following examples, labeled as "Report 1", "Report 2", and so on. Your answer should be a single paragraph (100-200 words) that summarizes the general content of all the reports.
 
230
 
231
- # {text}
232
 
233
- # General summary:"""
234
 
235
- prompt_template = """ You are an expert on clinicial trials and their analysis of their reports.
236
- # Task
237
- You will be given a text of descriptions of multiple clinical trials realed to similar diseases. Your job is to come up with a short and detailed summary of the descriptions of the clinical trials. Your users are clinical researchers, so you should be technical and specific, including scientific terms in the summary."""
238
 
239
  prompt = PromptTemplate.from_template(prompt_template)
240
 
 
221
  return filtered_data
222
 
223
  def get_short_summary_out_of_json_files(data_json):
224
+ prompt_template = """You are an expert on clinicial trials and their analysis of their reports.
225
+
226
+ # Task
227
+ You will be given a text of descriptions of multiple clinical trials realed to similar diseases. Your job is to come up with a short and detailed summary of the descriptions of the clinical trials. Your users are clinical researchers, so you should be technical and specific, including scientific terms in the summary.
228
 
229
+ {text}"""
230
+
231
+ prompt_template = """You are an expert clinician working on the analysis of reports of clinical trials.
232
 
233
+ # Task
234
+ You will be given a set of descriptions of clinical trials. Your job is to come up with a short summary (100-200 words) of the descriptions of the clinical trials. Your users are clinical researchers who are experts in medicine, so you should be technical and specific, including scientific terms. Always be faithful to the original information written in the reports.
235
 
236
+ To write your summary, you will need to read the following examples, labeled as "Report 1", "Report 2", and so on. Your answer should be a single paragraph (100-200 words) that summarizes the general content of all the reports. Format your answer in Markdown format, **highlighting** the most important concepts, and _italicizing_ the technical concepts extracted from the reports. Be very specific about the details of the clinical trials.
237
 
238
+ {text}
239
 
240
+ General summary:"""
 
 
241
 
242
  prompt = PromptTemplate.from_template(prompt_template)
243
 
requirements.txt CHANGED
@@ -11,3 +11,4 @@ sentence_transformers==2.7.0
11
  streamlit-agraph
12
  streamlit==1.34.0
13
  langchain-openai==0.1.6
 
 
11
  streamlit-agraph
12
  streamlit==1.34.0
13
  langchain-openai==0.1.6
14
+ matplotlib==3.8.4
utils.py CHANGED
@@ -5,6 +5,7 @@ from sqlalchemy import create_engine, text
5
  import requests
6
  from sentence_transformers import SentenceTransformer
7
  import streamlit as st
 
8
 
9
  username = "demo"
10
  password = "demo"
@@ -121,7 +122,12 @@ def get_similarities_among_diseases_uris(
121
  """
122
  result = conn.execute(text(sql))
123
  data = result.fetchall()
124
- return data
 
 
 
 
 
125
 
126
 
127
  def augment_the_set_of_diseaces(diseases: List[str]) -> str:
@@ -169,7 +175,7 @@ def get_diseases_related_to_a_textual_description(
169
  result = conn.execute(text(sql))
170
  data = result.fetchall()
171
 
172
- return [{"uri": row[0], "distance": row[1]} for row in data if row[1] > 0.8]
173
 
174
  def get_clinical_trials_related_to_diseases(
175
  diseases: List[str], encoder
@@ -191,6 +197,20 @@ def get_clinical_trials_related_to_diseases(
191
 
192
  return [{"nct_id": row[0], "distance": row[1]} for row in data]
193
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  def to_capitalized_case(string: str) -> str:
195
  string = string.replace("_", " ")
196
  if string.isupper():
@@ -206,36 +226,53 @@ def render_trial_details(trial: dict) -> None:
206
  official_title = trial["protocolSection"]["identificationModule"]["officialTitle"]
207
  st.write(f"##### {official_title}")
208
 
209
- brief_summary = trial["protocolSection"]["descriptionModule"]["briefSummary"]
210
- st.write(brief_summary)
 
 
 
 
 
211
 
212
- status_module = {
213
- "Status": to_capitalized_case(trial["protocolSection"]["statusModule"]["overallStatus"]),
214
- "Status Date": trial["protocolSection"]["statusModule"]["statusVerifiedDate"],
215
- "Has Results": trial["hasResults"]
216
- }
217
  st.write("###### Status")
218
- st.table(status_module)
219
-
220
- design_module = {
221
- "Study Type": to_capitalized_case(trial["protocolSection"]["designModule"]["studyType"]),
222
- "Phases": list_to_capitalized_case(trial["protocolSection"]["designModule"]["phases"]),
223
- "Allocation": to_capitalized_case(trial["protocolSection"]["designModule"]["designInfo"]["allocation"]),
224
- "Primary Purpose": to_capitalized_case(trial["protocolSection"]["designModule"]["designInfo"]["primaryPurpose"]),
225
- "Participants": trial["protocolSection"]["designModule"]["enrollmentInfo"]["count"],
226
- "Masking": to_capitalized_case(trial["protocolSection"]["designModule"]["designInfo"]["maskingInfo"]["masking"]),
227
- "Who Masked": list_to_capitalized_case(trial["protocolSection"]["designModule"]["designInfo"]["maskingInfo"]["whoMasked"])
228
- }
229
  st.write("###### Design")
230
- st.table(design_module)
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
- interventions_module = {}
233
- for intervention in trial["protocolSection"]["armsInterventionsModule"]["interventions"]:
234
- name = intervention["name"]
235
- desc = intervention["description"]
236
- interventions_module[name] = desc
237
  st.write("###### Interventions")
238
- st.table(interventions_module)
 
 
 
 
 
 
 
 
 
 
 
239
 
240
  if __name__ == "__main__":
241
  username = "demo"
 
5
  import requests
6
  from sentence_transformers import SentenceTransformer
7
  import streamlit as st
8
+ import pandas as pd
9
 
10
  username = "demo"
11
  password = "demo"
 
122
  """
123
  result = conn.execute(text(sql))
124
  data = result.fetchall()
125
+
126
+ return [{
127
+ "uri1": row[0].split("/")[-1],
128
+ "uri2": row[1].split("/")[-1],
129
+ "distance": float(row[2]),
130
+ } for row in data]
131
 
132
 
133
  def augment_the_set_of_diseaces(diseases: List[str]) -> str:
 
175
  result = conn.execute(text(sql))
176
  data = result.fetchall()
177
 
178
+ return [{"uri": row[0], "distance": float(row[1])} for row in data if float(row[1]) > 0.8]
179
 
180
  def get_clinical_trials_related_to_diseases(
181
  diseases: List[str], encoder
 
197
 
198
  return [{"nct_id": row[0], "distance": row[1]} for row in data]
199
 
200
def filter_out_less_promising_diseases(info_dicts: List[Dict[str, Any]]) -> tuple:
    """Drop the diseases whose average similarity to the others is unusually low.

    The pairwise records are pivoted into a matrix (rows ``uri1``, columns
    ``uri2``, cells ``distance``). Each column is scored by its mean, and a
    column is kept when its score is at least one standard deviation below
    the mean of all scores, or higher.

    Args:
        info_dicts: Pairwise records, each with the keys ``"uri1"``,
            ``"uri2"`` and ``"distance"``. Each ``(uri1, uri2)`` pair must
            appear at most once, otherwise ``DataFrame.pivot`` raises
            ``ValueError``.

    Returns:
        A 2-tuple ``(kept, matrix)`` where ``kept`` is the list of column
        labels (``uri2`` values) that passed the filter and ``matrix`` is the
        pivoted similarity ``DataFrame`` with missing cells filled with 1.0.
        For empty input this is ``([], <empty DataFrame>)``.
    """
    # Nothing to pivot: pivot() would raise KeyError on a frame with no columns.
    if not info_dicts:
        return [], pd.DataFrame()

    similarity_matrix = pd.DataFrame(info_dicts).pivot(
        index="uri1", columns="uri2", values="distance"
    )
    # Missing cells are treated as perfect similarity. This is meant to cover
    # the diagonal (a disease compared with itself), but note that it also
    # fills any other pair that is absent from the input.
    similarity_matrix = similarity_matrix.fillna(1.0)

    # One score per disease (column): its mean similarity to every row.
    scores = similarity_matrix.mean()
    spread = scores.std()

    # With a single score the sample standard deviation is NaN; comparing
    # against NaN would silently discard every disease, so keep them all.
    if pd.isna(spread):
        return scores.index.tolist(), similarity_matrix

    # ">=" rather than ">" so that when every score is identical (spread == 0)
    # the diseases sitting exactly on the threshold are kept, not all dropped.
    threshold = scores.mean() - spread
    kept = scores[scores >= threshold].index.tolist()
    return kept, similarity_matrix
213
+
214
  def to_capitalized_case(string: str) -> str:
215
  string = string.replace("_", " ")
216
  if string.isupper():
 
226
  official_title = trial["protocolSection"]["identificationModule"]["officialTitle"]
227
  st.write(f"##### {official_title}")
228
 
229
+ try:
230
+ st.write(trial["protocolSection"]["descriptionModule"]["briefSummary"])
231
+ except KeyError:
232
+ try:
233
+ st.write(trial["protocolSection"]["descriptionModule"]["detailedDescription"])
234
+ except KeyError:
235
+ st.error("No description available.")
236
 
 
 
 
 
 
237
  st.write("###### Status")
238
+ try:
239
+ status_module = {
240
+ "Status": to_capitalized_case(trial["protocolSection"]["statusModule"]["overallStatus"]),
241
+ "Status Date": trial["protocolSection"]["statusModule"]["statusVerifiedDate"],
242
+ "Has Results": trial["hasResults"]
243
+ }
244
+ st.table(status_module)
245
+ except KeyError:
246
+ st.info("No status information available.")
247
+
 
248
  st.write("###### Design")
249
+ try:
250
+ design_module = {
251
+ "Study Type": to_capitalized_case(trial["protocolSection"]["designModule"]["studyType"]),
252
+ "Phases": list_to_capitalized_case(trial["protocolSection"]["designModule"]["phases"]),
253
+ "Allocation": to_capitalized_case(trial["protocolSection"]["designModule"]["designInfo"]["allocation"]),
254
+ "Primary Purpose": to_capitalized_case(trial["protocolSection"]["designModule"]["designInfo"]["primaryPurpose"]),
255
+ "Participants": trial["protocolSection"]["designModule"]["enrollmentInfo"]["count"],
256
+ "Masking": to_capitalized_case(trial["protocolSection"]["designModule"]["designInfo"]["maskingInfo"]["masking"]),
257
+ "Who Masked": list_to_capitalized_case(trial["protocolSection"]["designModule"]["designInfo"]["maskingInfo"]["whoMasked"])
258
+ }
259
+ st.table(design_module)
260
+ except KeyError:
261
+ st.info("No design information available.")
262
 
 
 
 
 
 
263
  st.write("###### Interventions")
264
+ try:
265
+ interventions_module = {}
266
+ for intervention in trial["protocolSection"]["armsInterventionsModule"]["interventions"]:
267
+ name = intervention["name"]
268
+ desc = intervention["description"]
269
+ interventions_module[name] = desc
270
+ st.table(interventions_module)
271
+ except KeyError:
272
+ st.info("No interventions information available.")
273
+
274
+ # Button to go to ClinicalTrials.gov and see the trial. It takes the user to the official page of the trial.
275
+ st.markdown(f"See more in [ClinicalTrials.gov](https://clinicaltrials.gov/study/{trial['protocolSection']['identificationModule']['nctId']})")
276
 
277
  if __name__ == "__main__":
278
  username = "demo"