Spaces:
Sleeping
Sleeping
ACMCMC
commited on
Commit
•
4bb7c94
1
Parent(s):
1fb895a
WIP
Browse files- app.py +44 -10
- llm_res.py +49 -36
- utils.py +1 -1
app.py
CHANGED
@@ -28,6 +28,7 @@ show_graph = False
|
|
28 |
show_analyze_status = False
|
29 |
show_overview = False
|
30 |
show_details = False
|
|
|
31 |
|
32 |
# IRIS connection
|
33 |
username = "demo"
|
@@ -66,7 +67,7 @@ with st.container():
|
|
66 |
diseases_related_to_the_user_text = get_diseases_related_to_a_textual_description(
|
67 |
description_input, encoder
|
68 |
)
|
69 |
-
status.info(f'
|
70 |
status.json(diseases_related_to_the_user_text, expanded=False)
|
71 |
status.divider()
|
72 |
# 3. Get the similarities of the embeddings of those diseases (cosine similarity of the embeddings of the nodes of such diseases)
|
@@ -94,26 +95,37 @@ with st.container():
|
|
94 |
clinical_trials_related_to_the_diseases = get_clinical_trials_related_to_diseases(
|
95 |
augmented_set_of_diseases, encoder
|
96 |
)
|
97 |
-
status.info(f'
|
98 |
status.json(clinical_trials_related_to_the_diseases, expanded=False)
|
99 |
status.divider()
|
100 |
status.write("Getting the details of the clinical trials...")
|
101 |
json_of_clinical_trials = get_clinical_records_by_ids(
|
102 |
[trial["nct_id"] for trial in clinical_trials_related_to_the_diseases]
|
103 |
)
|
|
|
104 |
status.json(json_of_clinical_trials, expanded=False)
|
105 |
status.divider()
|
106 |
# 7. Use an LLM to get a summary of the clinical trials, in plain text format.
|
107 |
-
|
108 |
-
|
109 |
-
|
|
|
|
|
|
|
|
|
|
|
110 |
try:
|
111 |
# 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of patients, number of deaths, etc.). Get summary statistics out of that.
|
112 |
status.write("Getting summary statistics of the clinical trials...")
|
113 |
response = tagging_insights_from_json(json_of_clinical_trials)
|
|
|
|
|
|
|
|
|
114 |
print(f'Response from LLM tagging: {response}')
|
115 |
-
status.
|
116 |
except Exception as e:
|
|
|
117 |
print(f'Error while extracting numerical data from the clinical trials: {e}')
|
118 |
status.warning(f'Error while extracting numerical data from the clinical trials. This information will not be shown.')
|
119 |
# 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
|
@@ -170,10 +182,32 @@ $$"""
|
|
170 |
# overview
|
171 |
with st.container():
|
172 |
if show_overview:
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
177 |
|
178 |
|
179 |
# details
|
|
|
28 |
show_analyze_status = False
|
29 |
show_overview = False
|
30 |
show_details = False
|
31 |
+
show_metrics = False
|
32 |
|
33 |
# IRIS connection
|
34 |
username = "demo"
|
|
|
67 |
diseases_related_to_the_user_text = get_diseases_related_to_a_textual_description(
|
68 |
description_input, encoder
|
69 |
)
|
70 |
+
status.info(f'Selected {len(diseases_related_to_the_user_text)} diseases related to the description you entered.')
|
71 |
status.json(diseases_related_to_the_user_text, expanded=False)
|
72 |
status.divider()
|
73 |
# 3. Get the similarities of the embeddings of those diseases (cosine similarity of the embeddings of the nodes of such diseases)
|
|
|
95 |
clinical_trials_related_to_the_diseases = get_clinical_trials_related_to_diseases(
|
96 |
augmented_set_of_diseases, encoder
|
97 |
)
|
98 |
+
status.info(f'Selected {len(clinical_trials_related_to_the_diseases)} clinical trials related to the diseases.')
|
99 |
status.json(clinical_trials_related_to_the_diseases, expanded=False)
|
100 |
status.divider()
|
101 |
status.write("Getting the details of the clinical trials...")
|
102 |
json_of_clinical_trials = get_clinical_records_by_ids(
|
103 |
[trial["nct_id"] for trial in clinical_trials_related_to_the_diseases]
|
104 |
)
|
105 |
+
status.success(f'Details of the clinical trials obtained.')
|
106 |
status.json(json_of_clinical_trials, expanded=False)
|
107 |
status.divider()
|
108 |
# 7. Use an LLM to get a summary of the clinical trials, in plain text format.
|
109 |
+
try:
|
110 |
+
status.write("Getting a summary of the clinical trials...")
|
111 |
+
response = get_short_summary_out_of_json_files(json_of_clinical_trials)
|
112 |
+
status.success("Summary of the clinical trials obtained.")
|
113 |
+
disease_overview = response
|
114 |
+
except Exception as e:
|
115 |
+
print(f'Error while getting a summary of the clinical trials: {e}')
|
116 |
+
status.warning(f'Error while getting a summary of the clinical trials. This information will not be shown.')
|
117 |
try:
|
118 |
# 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of patients, number of deaths, etc.). Get summary statistics out of that.
|
119 |
status.write("Getting summary statistics of the clinical trials...")
|
120 |
response = tagging_insights_from_json(json_of_clinical_trials)
|
121 |
+
average_minimum_age = response["avg_min_age"]
|
122 |
+
average_maximum_age = response["avg_max_age"]
|
123 |
+
most_common_gender = response['most_common_gender']
|
124 |
+
|
125 |
print(f'Response from LLM tagging: {response}')
|
126 |
+
status.success(f'Summary statistics of the clinical trials obtained.')
|
127 |
except Exception as e:
|
128 |
+
raise e
|
129 |
print(f'Error while extracting numerical data from the clinical trials: {e}')
|
130 |
status.warning(f'Error while extracting numerical data from the clinical trials. This information will not be shown.')
|
131 |
# 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
|
|
|
182 |
# overview
|
183 |
with st.container():
|
184 |
if show_overview:
|
185 |
+
try:
|
186 |
+
st.write("## Overview of Related Clinical Trials")
|
187 |
+
st.write(disease_overview)
|
188 |
+
time.sleep(1)
|
189 |
+
except Exception as e:
|
190 |
+
print(f'Error while showing the overview of the clinical trials: {e}')
|
191 |
+
finally:
|
192 |
+
show_metrics = True
|
193 |
+
|
194 |
+
|
195 |
+
with st.container():
|
196 |
+
if show_metrics:
|
197 |
+
try:
|
198 |
+
st.write("## Metrics of the Clinical Trials")
|
199 |
+
col1, col2, col3 = st.columns(3)
|
200 |
+
with col1:
|
201 |
+
st.metric("Average Minimum Age", average_minimum_age)
|
202 |
+
with col2:
|
203 |
+
st.metric("Average Maximum Age", average_maximum_age)
|
204 |
+
with col3:
|
205 |
+
st.metric("Most Common Gender", most_common_gender)
|
206 |
+
time.sleep(2)
|
207 |
+
except Exception as e:
|
208 |
+
print(f'Error while showing the metrics: {e}')
|
209 |
+
finally:
|
210 |
+
show_details = True
|
211 |
|
212 |
|
213 |
# details
|
llm_res.py
CHANGED
@@ -24,6 +24,7 @@ from langchain.chains.llm import LLMChain
|
|
24 |
from langchain_core.prompts import PromptTemplate
|
25 |
from collections import Counter
|
26 |
import statistics
|
|
|
27 |
|
28 |
load_dotenv()
|
29 |
|
@@ -134,11 +135,12 @@ def get_clinical_records_by_ids(clinical_record_ids: List[str]) -> List[Dict[str
|
|
134 |
# "eligibility": eligibility,
|
135 |
# }
|
136 |
# filtered_data.append(filtered_item)
|
137 |
-
|
138 |
# return filtered_data
|
139 |
# # for ele in filtered_data:
|
140 |
# # print(ele)
|
141 |
|
|
|
142 |
def process_dictionaty_with_llm_to_generate_response(json_data):
|
143 |
# processed_data = process_json_data_for_llm(json_data)
|
144 |
# res = tagging_chain.invoke({"input": processed_data})
|
@@ -217,9 +219,10 @@ def process_dictionaty_with_llm_to_generate_response(json_data):
|
|
217 |
"eligibility": eligibility,
|
218 |
}
|
219 |
filtered_data.append(filtered_item)
|
220 |
-
|
221 |
return filtered_data
|
222 |
|
|
|
223 |
def get_short_summary_out_of_json_files(data_json):
|
224 |
prompt_template = """You are an expert on clinicial trials and their analysis of their reports.
|
225 |
|
@@ -272,29 +275,36 @@ General summary:"""
|
|
272 |
|
273 |
return result
|
274 |
|
|
|
275 |
def analyze_data(data):
|
276 |
-
|
277 |
-
|
278 |
-
|
|
|
279 |
# primary_timeframe= [int(age.split()[0]) for age in data['[primary_outcome]'] if age]
|
280 |
-
|
281 |
# Calculate average minimum and maximum ages
|
282 |
avg_min_age = statistics.mean(min_ages) if min_ages else None
|
283 |
avg_max_age = statistics.mean(max_ages) if max_ages else None
|
284 |
-
|
285 |
# Find most common gender
|
286 |
-
gender_counter = Counter(data[
|
287 |
most_common_gender = gender_counter.most_common(1)[0][0]
|
288 |
-
|
289 |
# Flatten keywords list and find common keywords
|
290 |
-
keywords = [keyword for sublist in data[
|
291 |
-
common_keywords = [word for word, count in Counter(keywords).most_common()]
|
292 |
-
|
293 |
-
return
|
|
|
|
|
|
|
|
|
|
|
294 |
|
295 |
def tagging_insights_from_json(data_json):
|
296 |
-
processed_json= process_dictionaty_with_llm_to_generate_response(data_json)
|
297 |
-
|
298 |
tagging_prompt = ChatPromptTemplate.from_template(
|
299 |
"""
|
300 |
You are an expert on clinicial trials and analysis of their reports.
|
@@ -307,6 +317,7 @@ def tagging_insights_from_json(data_json):
|
|
307 |
{input}
|
308 |
"""
|
309 |
)
|
|
|
310 |
class Classification(BaseModel):
|
311 |
# description: str = Field(
|
312 |
# description="text description grouping all the clinical trials using briefDescription and detailedDescription keys"
|
@@ -317,25 +328,25 @@ def tagging_insights_from_json(data_json):
|
|
317 |
# status: list = Field(
|
318 |
# description="Extract the status of all the clinical trials"
|
319 |
# )
|
320 |
-
#keywords: list = Field(
|
321 |
# description="Extract the most relevant keywords for each clinical trials"
|
322 |
-
#)
|
323 |
# interventions: list = Field(
|
324 |
# description="describe the interventions for each clinical trial using title, name and description"
|
325 |
# )
|
326 |
-
#primary_outcomes: list = Field(
|
327 |
# description="get the timeframe of each clinical trial"
|
328 |
-
#)
|
329 |
-
#secondary_outcomes: list= Field(description= "get the secondary outcomes of each clinical trial")
|
330 |
-
#eligibility: list = Field(
|
331 |
# description="get the timeframe of each clinical trial"
|
332 |
-
#)
|
333 |
# healthy_volunteers: list= Field(description= "determine whether the clinical trial requires healthy volunteers")
|
334 |
minimum_age: list = Field(
|
335 |
-
|
336 |
)
|
337 |
maximum_age: list = Field(
|
338 |
-
|
339 |
)
|
340 |
gender: list = Field(description="get the gender from each experiment")
|
341 |
|
@@ -343,15 +354,15 @@ def tagging_insights_from_json(data_json):
|
|
343 |
return {
|
344 |
# "project_title": self.project_title,
|
345 |
# "status": self.status,
|
346 |
-
#"keywords": self.keywords,
|
347 |
# "interventions": self.interventions,
|
348 |
-
#"primary_outcomes": self.primary_outcomes,
|
349 |
-
#"secondary_outcomes": self.secondary_outcomes,
|
350 |
# "eligibility": self.eligibility,
|
351 |
# "healthy_volunteers": self.healthy_volunteers,
|
352 |
"minimum_age": self.minimum_age,
|
353 |
"maximum_age": self.maximum_age,
|
354 |
-
"gender": self.gender
|
355 |
}
|
356 |
|
357 |
# LLM
|
@@ -365,18 +376,20 @@ def tagging_insights_from_json(data_json):
|
|
365 |
|
366 |
tagging_chain = tagging_prompt | llm
|
367 |
|
368 |
-
res= tagging_chain.invoke({"input": processed_json})
|
369 |
-
|
370 |
|
371 |
-
|
|
|
|
|
372 |
|
373 |
-
#stats_dict= {'Average Minimum age': avg_min_age,
|
374 |
# 'Average Maximum age': avg_max_age,
|
375 |
# 'Most common gender undergoing the trials': most_common_gender,
|
376 |
# 'common keywords found in the trials': common_keywords}
|
377 |
-
|
378 |
-
print(f"Result_tagging: {
|
379 |
-
return
|
380 |
|
381 |
|
382 |
# clinical_record_info = get_clinical_records_by_ids(['NCT00841061', 'NCT03035123', 'NCT02272751', 'NCT03035123', 'NCT03055377'])
|
@@ -386,4 +399,4 @@ def tagging_insights_from_json(data_json):
|
|
386 |
# json.dump(clinical_record_info, f, indent=4)
|
387 |
|
388 |
|
389 |
-
# tagging_chain = tagging_insights_from_json(json_data)
|
|
|
24 |
from langchain_core.prompts import PromptTemplate
|
25 |
from collections import Counter
|
26 |
import statistics
|
27 |
+
import regex as re
|
28 |
|
29 |
load_dotenv()
|
30 |
|
|
|
135 |
# "eligibility": eligibility,
|
136 |
# }
|
137 |
# filtered_data.append(filtered_item)
|
138 |
+
|
139 |
# return filtered_data
|
140 |
# # for ele in filtered_data:
|
141 |
# # print(ele)
|
142 |
|
143 |
+
|
144 |
def process_dictionaty_with_llm_to_generate_response(json_data):
|
145 |
# processed_data = process_json_data_for_llm(json_data)
|
146 |
# res = tagging_chain.invoke({"input": processed_data})
|
|
|
219 |
"eligibility": eligibility,
|
220 |
}
|
221 |
filtered_data.append(filtered_item)
|
222 |
+
|
223 |
return filtered_data
|
224 |
|
225 |
+
|
226 |
def get_short_summary_out_of_json_files(data_json):
|
227 |
prompt_template = """You are an expert on clinicial trials and their analysis of their reports.
|
228 |
|
|
|
275 |
|
276 |
return result
|
277 |
|
278 |
+
|
279 |
def analyze_data(data):
|
280 |
+
print(f"Data: {data}")
|
281 |
+
# Extract minimum and maximum ages: Turn ['18 Years', '20 Years'] into [18, 20]
|
282 |
+
min_ages = [int(re.search(r"\d+", age).group()) for age in data["minimum_age"] if age]
|
283 |
+
max_ages = [int(re.search(r"\d+", age).group()) for age in data["maximum_age"] if age]
|
284 |
# primary_timeframe= [int(age.split()[0]) for age in data['[primary_outcome]'] if age]
|
285 |
+
|
286 |
# Calculate average minimum and maximum ages
|
287 |
avg_min_age = statistics.mean(min_ages) if min_ages else None
|
288 |
avg_max_age = statistics.mean(max_ages) if max_ages else None
|
289 |
+
|
290 |
# Find most common gender
|
291 |
+
gender_counter = Counter(data["gender"])
|
292 |
most_common_gender = gender_counter.most_common(1)[0][0]
|
293 |
+
|
294 |
# Flatten keywords list and find common keywords
|
295 |
+
#keywords = [keyword for sublist in data["keywords"] for keyword in sublist]
|
296 |
+
#common_keywords = [word for word, count in Counter(keywords).most_common()]
|
297 |
+
|
298 |
+
return {
|
299 |
+
"avg_min_age": avg_min_age,
|
300 |
+
"avg_max_age": avg_max_age,
|
301 |
+
"most_common_gender": most_common_gender
|
302 |
+
}
|
303 |
+
|
304 |
|
305 |
def tagging_insights_from_json(data_json):
|
306 |
+
processed_json = process_dictionaty_with_llm_to_generate_response(data_json)
|
307 |
+
|
308 |
tagging_prompt = ChatPromptTemplate.from_template(
|
309 |
"""
|
310 |
You are an expert on clinicial trials and analysis of their reports.
|
|
|
317 |
{input}
|
318 |
"""
|
319 |
)
|
320 |
+
|
321 |
class Classification(BaseModel):
|
322 |
# description: str = Field(
|
323 |
# description="text description grouping all the clinical trials using briefDescription and detailedDescription keys"
|
|
|
328 |
# status: list = Field(
|
329 |
# description="Extract the status of all the clinical trials"
|
330 |
# )
|
331 |
+
# keywords: list = Field(
|
332 |
# description="Extract the most relevant keywords for each clinical trials"
|
333 |
+
# )
|
334 |
# interventions: list = Field(
|
335 |
# description="describe the interventions for each clinical trial using title, name and description"
|
336 |
# )
|
337 |
+
# primary_outcomes: list = Field(
|
338 |
# description="get the timeframe of each clinical trial"
|
339 |
+
# )
|
340 |
+
# secondary_outcomes: list= Field(description= "get the secondary outcomes of each clinical trial")
|
341 |
+
# eligibility: list = Field(
|
342 |
# description="get the timeframe of each clinical trial"
|
343 |
+
# )
|
344 |
# healthy_volunteers: list= Field(description= "determine whether the clinical trial requires healthy volunteers")
|
345 |
minimum_age: list = Field(
|
346 |
+
description="get the minimum age from each experiment"
|
347 |
)
|
348 |
maximum_age: list = Field(
|
349 |
+
description="get the maximum age from each experiment"
|
350 |
)
|
351 |
gender: list = Field(description="get the gender from each experiment")
|
352 |
|
|
|
354 |
return {
|
355 |
# "project_title": self.project_title,
|
356 |
# "status": self.status,
|
357 |
+
# "keywords": self.keywords,
|
358 |
# "interventions": self.interventions,
|
359 |
+
# "primary_outcomes": self.primary_outcomes,
|
360 |
+
# "secondary_outcomes": self.secondary_outcomes,
|
361 |
# "eligibility": self.eligibility,
|
362 |
# "healthy_volunteers": self.healthy_volunteers,
|
363 |
"minimum_age": self.minimum_age,
|
364 |
"maximum_age": self.maximum_age,
|
365 |
+
"gender": self.gender,
|
366 |
}
|
367 |
|
368 |
# LLM
|
|
|
376 |
|
377 |
tagging_chain = tagging_prompt | llm
|
378 |
|
379 |
+
res = tagging_chain.invoke({"input": processed_json})
|
380 |
+
unprocessed_results_dict = res.get_dict()
|
381 |
|
382 |
+
results_dict = analyze_data(
|
383 |
+
unprocessed_results_dict
|
384 |
+
)
|
385 |
|
386 |
+
# stats_dict= {'Average Minimum age': avg_min_age,
|
387 |
# 'Average Maximum age': avg_max_age,
|
388 |
# 'Most common gender undergoing the trials': most_common_gender,
|
389 |
# 'common keywords found in the trials': common_keywords}
|
390 |
+
|
391 |
+
print(f"Result_tagging: {results_dict}")
|
392 |
+
return results_dict
|
393 |
|
394 |
|
395 |
# clinical_record_info = get_clinical_records_by_ids(['NCT00841061', 'NCT03035123', 'NCT02272751', 'NCT03035123', 'NCT03055377'])
|
|
|
399 |
# json.dump(clinical_record_info, f, indent=4)
|
400 |
|
401 |
|
402 |
+
# tagging_chain = tagging_insights_from_json(json_data)
|
utils.py
CHANGED
@@ -189,7 +189,7 @@ def get_clinical_trials_related_to_diseases(
|
|
189 |
with engine.connect() as conn:
|
190 |
with conn.begin():
|
191 |
sql = f"""
|
192 |
-
SELECT TOP
|
193 |
FROM Test.ClinicalTrials d
|
194 |
ORDER BY distance DESC
|
195 |
"""
|
|
|
189 |
with engine.connect() as conn:
|
190 |
with conn.begin():
|
191 |
sql = f"""
|
192 |
+
SELECT TOP 15 d.nct_id, VECTOR_COSINE(d.embedding, TO_VECTOR('{string_representation}', DOUBLE)) AS distance
|
193 |
FROM Test.ClinicalTrials d
|
194 |
ORDER BY distance DESC
|
195 |
"""
|