# NOTE: the lines "Spaces:" / "Runtime error" above this file's imports were
# scraped hosting-page UI text, not program source; kept here only as a comment
# so the file parses.
# language default packages | |
from datetime import datetime | |
# external packages | |
import gradio as gr | |
import asyncio | |
# from langchain.llms import OpenAI | |
# from langchain.prompts import PromptTemplate | |
# from langchain.chains import LLMChain | |
# internal packages | |
from chains import * | |
from cloud_db import * | |
from cloud_storage import * | |
from supplier import * | |
from utility import list_dict_to_dict | |
# get prompts, terms, outputs from the cloud | |
def init_app_data():
    """Load prompts, terms, articles, summaries, devices and aggregation
    prompts from the cloud backend into the module-level ``app_data`` dict.

    Returns
    -------
    None
    """
    # Tables that get re-keyed from a list of rows into a dict keyed on
    # one of their columns.
    keyed_tables = {
        "prompts": "name",
        "articles": "name",
        "summary": "term",
        "devices": "device_name",
        "prompts_agg": "assessment",
    }
    for table_name, key_field in keyed_tables.items():
        app_data[table_name] = list_dict_to_dict(get_table(table_name), key=key_field)
    # "terms" stays as the raw list of rows.
    app_data["terms"] = get_table("terms")
def get_ifu(device_name="TranscendTM NanoTec™ Interbody System"):
    """Build the IFU text (contraindications, indications, intended use)
    for *device_name* from the cached device table in ``app_data``."""
    device = app_data["devices"][device_name]
    sections = (device["contraindications"], device["indications"], device["intended_use"])
    return "\n".join(str(section) for section in sections)
def get_existing_article(
    article_name,
):
    """Look up a previously processed article by name, make it the current
    article, and return its rendered views.

    Parameters
    ----------
    article_name : str
        Name of the article.

    Returns
    -------
    tuple
        (overview update, detail-views update) for the Gradio UI.
    """
    selected = app_data["articles"][article_name]
    app_data["current_article"] = selected
    return create_overview(selected), create_detail_views(selected)
def process_study(
    domain,
    device_ifu,
    study_file_obj,
    study_content,
):
    """Ingest one study (uploaded file or raw text), run the extraction
    pipeline on it, persist the result, and return the rendered views.

    Parameters
    ----------
    domain : str
        Subject domain of the study.
    device_ifu : str
        Device name used to look up the IFU text.
    study_file_obj : file object or None
        Uploaded study file; takes precedence over ``study_content``.
    study_content : str or None
        Raw study text, used when no file object is given.

    Returns
    -------
    tuple
        (overview update, detail-views update) for the Gradio UI.
    """
    if study_file_obj:
        article = add_article(domain, study_file_obj)
    elif study_content:
        article = add_article(domain, study_content, file_object=False)
    else:
        # BUG FIX: this branch used to return a 3-tuple while the success
        # path returns 2 values; the arity now matches the success path.
        return "No file or content provided", "No file or content provided"
    # update the common article segments from the raw content and the IFU
    update_article_segment(article, device_ifu)
    # perform pathway logic and content extraction
    process_prompts(article=article)
    # perform a post process for the performance follow-up tables
    post_process(article)
    # cache the finished article as the current one
    app_data["current_article"] = article
    app_data["articles"][article["name"]] = article
    # push the article to the cloud; failure here should not lose the
    # locally cached result, so it is only logged
    try:
        update_article(article)
    except Exception as e:
        print(e)
    # create overview and detail markdown views for the article
    detail_views = create_detail_views(article)
    overview = create_overview(article)
    return overview, detail_views
def process_studies(
    domain,
    file_objs):
    """Batch-process several uploaded study files and refresh the article
    listing table.

    Parameters
    ----------
    domain : str
        Subject domain applied to every uploaded study.
    file_objs : list
        Uploaded file objects.

    Returns
    -------
    gradio update
        Refreshed markdown article table.
    """
    # NOTE(review): this default matches get_ifu's default device — TODO
    # confirm the batch-upload UI should not let the user pick the device.
    default_device = "TranscendTM NanoTec™ Interbody System"
    for file_obj in file_objs:
        # BUG FIX: process_study takes (domain, device_ifu, file_obj,
        # content); the old 3-argument call raised a TypeError.
        process_study(domain, default_device, file_obj, None)
    return gr.update(value=create_md_tables(app_data["articles"]))
def create_md_tables(articles):
    """Render the article dictionary as a single markdown table with
    name, authors, domain and upload-time columns."""
    header = "| Article Name | Authors | Domain | Upload Time |\n| --- | --- | --- | --- |\n"
    rows = [
        f"| {name} | {entry['Authors']} |{entry['domain']} | {entry['upload_time']} | \n"
        for name, entry in articles.items()
    ]
    return header + "".join(rows)
def update_article_segment(article, device_ifu):
    """Slice the raw article text into its standard sections, attach the
    device IFU, derive the pathway-logic flags, and run the async
    segment-extraction prompts.

    Mutates *article* in place; callers do not need a return value.

    Parameters
    ----------
    article : dict
        Article object holding the "raw" text.
    device_ifu : str
        Device name used to look up the IFU text.
    """
    raw_content = article["raw"]
    lowered = raw_content.lower()
    # everything before the discussion section is treated as meta content
    index_discussion = lowered.index("discussion") if "discussion" in lowered else len(raw_content)
    meta_content = raw_content[:index_discussion]
    # walk the document section by section; each call consumes the text up
    # to the next marker.  NOTE(review): some articles (e.g. Liu) lack
    # "objective"/"key" and use "introduction" instead — confirm coverage.
    abstract, next_content = get_key_content(raw_content, "objective", "key")
    introduction, next_content = get_key_content(next_content, "key", "methods")
    materials_and_methods, next_content = get_key_content(next_content, "methods", "results")
    results, _ = get_key_content(next_content, "results", "discussion")
    article.update({
        "Abstract": abstract,
        "Introduction": introduction,
        "Material and Methods": materials_and_methods,
        "Results": results,
        "Meta Content": meta_content,
        "IFU": get_ifu(device_ifu),
        "tables": "",
    })
    # key_content aggregates the sections most prompts search over
    article["key_content"] = article["Abstract"] + article["Material and Methods"] + article["Results"]
    # add the recognized logic flags to the article
    update_logic(article)
    # BUG FIX: the old fallback path caught every exception and called
    # asyncio.gather() outside a running loop without awaiting it, so the
    # extraction silently never executed.  Run it on a dedicated loop and
    # always close the loop, even on error.
    pre_loop = asyncio.new_event_loop()
    try:
        pre_loop.run_until_complete(get_segments(article, article_prompts))
    finally:
        pre_loop.close()
# need to review this. | |
async def get_segments(article, prompts):
    """Run every segment-extraction prompt against the article's meta
    content concurrently, letting async_generate store each result."""
    tasks = []
    for name, instruction in prompts.items():
        template = ChatPromptTemplate.from_messages([
            ("human", article["Meta Content"]),
            ("system", "From the text above " + instruction),
        ])
        tasks.append(async_generate(article, name, template | llm))
    await asyncio.gather(*tasks)
def refresh():
    """Reload the application data from the cloud, re-run the prompt
    pipeline on the current article, and refresh the UI widgets.

    Returns
    -------
    tuple
        (overview update, detail-views update, article-dropdown update).
    """
    init_app_data()
    article = app_data["current_article"]
    choices_update = gr.update(choices=list(app_data["articles"].keys()))
    if not article:
        # BUG FIX: this branch used to return a single string while the
        # success path returns three values; keep the arity consistent.
        msg = "No file or content provided"
        return msg, msg, choices_update
    process_prompts(article)
    detail_views = create_detail_views(article)
    overview = create_overview(article)
    update_article(article=article)
    return overview, detail_views, choices_update
def create_overview(article):
    """Assemble the overview markdown from the article's extracted
    overview components and return it as a Gradio update."""
    parts = ["## Overview\n\n"]
    for component in article["extraction"]["overview"]:
        parts.append(article[component] + "\n\n" if component in article else "no content found\n\n")
    return gr.update(value="".join(parts))
def pre_view(content):
    """Clean extracted table text for display.

    Drops the literal "Table Heading" marker and any line mentioning
    "article id", turns the first remaining line into a level-3 markdown
    heading of its alphanumeric characters, and strips double quotes.

    Parameters
    ----------
    content : str
        Raw extracted table text.

    Returns
    -------
    str
        Cleaned markdown text.
    """
    if "Table Heading" in content:  # remove table heading marker
        content = content.replace("Table Heading", "")
    # drop any line carrying the article id
    lines = [line for line in content.split("\n") if "article id" not in line.lower()]
    # BUG FIX: the old code called .split("\n") on the already-split list,
    # which raised AttributeError; operate on the line list directly.
    if lines:
        lines[0] = "###" + "".join(ch for ch in lines[0] if ch.isalnum())
    return "\n".join(lines).replace('"', '')
def create_detail_views(article):
    """Build the performance markdown view from whichever per-assessment
    follow-up tables are present on the article."""
    section_map = [
        ("clinical", "clin-perfFUtable-FIN"),
        ("radiologic", "rad-perfFUtable-FIN"),
        ("safety", "saf-Futable-FIN"),
        ("other", "oth-perfFUtable-FIN"),
    ]
    md_text = "## Performance\n\n"
    for assessment, table_key in section_map:
        if table_key in article:
            md_text += f"### {assessment.capitalize()}\n\n"
            md_text += article[table_key]
    return gr.update(value=md_text)
def get_key_content(text: str, start, end: str, case_sensitive: bool = False):
    """Extract the content between the *start* and *end* markers.

    Parameters
    ----------
    text : str
        Text of the article.
    start : str or int
        Start marker substring, or an explicit start index.
    end : str
        End marker substring.
    case_sensitive : bool, optional
        When False (default) matching is done on lowercased text.

    Returns
    -------
    tuple of (str, str)
        (content between start and end, remaining text from end onward).
        If *start* is not found, extraction begins at position 0; if
        *end* is not found, both elements are the text from start to the
        end of the string.
    """
    if not case_sensitive:
        # BUG FIX: this guard was commented out, leaving the
        # case_sensitive parameter dead; default behavior is unchanged.
        text = text.lower()
        end = end.lower()
        if isinstance(start, str):
            start = start.lower()
    if isinstance(start, str):
        start_index = text.find(start)
    else:
        start_index = start
    # if the start is not found, begin at the start of the text
    if start_index == -1:
        start_index = 0
    # BUG FIX: the end marker was searched from position 0, so an end
    # word occurring *before* the start marker (e.g. "results" mentioned
    # early in the abstract) produced an empty or truncated slice — the
    # "not getting the materials and methods" problem.  Search after the
    # start position instead.
    end_index = text.find(end, start_index)
    # if the end is not found, return from start to the end of the text
    # for both the searched text and the remaining text
    if end_index == -1:
        return text[start_index:], text[start_index:]
    return text[start_index:end_index], text[end_index:]
def get_articles(update_local=True):
    """Fetch the article rows from the cloud.

    Parameters
    ----------
    update_local : bool, optional
        When True (default) also refresh the local ``app_data`` cache.

    Returns
    -------
    list
        List of article row dicts.
    """
    articles = get_table("articles")
    if update_local:
        # BUG FIX: the cache was rebuilt without key="name", so name-based
        # lookups like app_data["articles"][article_name] used elsewhere
        # could break; key on "name" exactly as init_app_data does.
        app_data["articles"] = list_dict_to_dict(articles, key="name")
    return articles
def get_article(domain, name):
    """Fetch a single article record from the cloud.

    Parameters
    ----------
    domain : str
        Subject domain of the article.
    name : str
        Name of the article.

    Returns
    -------
    dict
        The article object.
    """
    return get_item("articles", {"domain": domain, "name": name})
def add_article(domain, file, add_to_s3=True, add_to_local=True, file_object=True):
    """Add an article (uploaded file or raw text) to the cloud table, S3
    and the local cache.

    Parameters
    ----------
    domain : str
        Subject domain of the article.
    file : str or file object
        Raw article text, or an uploaded file object with a ``name``.
    add_to_s3 : bool, optional
        Upload the article to the S3 bucket, by default True.
    add_to_local : bool, optional
        Cache the article in local memory, by default True.
    file_object : bool, optional
        Unused; the runtime type of *file* decides the branch taken.

    Returns
    -------
    dict
        The new article object, or the error response from the cloud.
    """
    if isinstance(file, str):
        # NOTE(review): the raw text doubles as the S3 object name here —
        # confirm this is intended rather than a separate title field.
        content = file
        filename = file
        # BUG FIX: uploads now honor add_to_s3, which was silently ignored.
        if add_to_s3:
            upload_file(file, default_s3_bucket, filename)
    else:
        # extract the text content from the uploaded pdf
        content, _ = read_pdf(file)
        # normalise Windows and POSIX style paths to a bare filename
        filename = file.name.replace("\\", "/").split("/")[-1]
        if add_to_s3:
            # BUG FIX: the pdf handle was opened without a context manager
            # and leaked if the upload raised; `with` guarantees closure.
            with open(file.name, 'rb') as pdf_obj:
                upload_fileobj(pdf_obj, default_s3_bucket, filename)
    article = {
        "domain": domain,
        "name": filename,
        "raw": content,
        "upload_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }
    if add_to_local:
        app_data["articles"][article["name"]] = article
    res = post_item("articles", article)
    if "Error" in res:
        print(res["Error"])
        return res
    return article
def remove_article(domain, name, remove_from_s3=True, remove_from_local=True):
    """Remove an article from the cloud table and, optionally, from S3
    and the local cache.

    Parameters
    ----------
    domain : str
        Subject domain of the article.
    name : str
        Name of the article.
    remove_from_s3 : bool, optional
        Also delete the file from the S3 bucket, by default True.
    remove_from_local : bool, optional
        Also drop the article from local memory, by default True.

    Returns
    -------
    bool
        True on completion.
    """
    # BUG FIX: delete_item was called twice (before and after the optional
    # removals) and a stray `pass` was left in the body; one call suffices.
    delete_item("articles", {"domain": domain, "name": name})
    if remove_from_s3:
        delete_file(domain, name)
    if remove_from_local:
        # pop() tolerates an article that was never cached locally,
        # where `del` would raise KeyError
        app_data["articles"].pop(name, None)
    return True
def update_article(article, file_obj=None, update_local=True):
    """Persist *article* to the cloud and optionally to S3 / local memory.

    Parameters
    ----------
    article : dict
        Article object.
    file_obj : file object, optional
        When given, the file is re-uploaded to S3, by default None.
    update_local : bool, optional
        Refresh the local cache entry, by default True.

    Returns
    -------
    dict
        The same article object.
    """
    if file_obj:
        upload_fileobj(file_obj, article["domain"], article["name"])
    if update_local:
        app_data["articles"][article["name"]] = article
    post_item("articles", article)
    return article
def select_overview_prompts(article):
    """Select the overview prompts whose terms validate against the
    article and record their ordered names on the article.

    Parameters
    ----------
    article : dict
        Article object; its "extraction" map is updated in place.

    Returns
    -------
    dict
        Mapping of prompt name -> prompt object, in section order.
    """
    valid_prompts = set()
    for term in app_data["terms"]:
        if validate_term(article, term, "overview"):
            valid_prompts.update(term["instruction"])
    # order prompts by their configured section sequence
    sorted_prompts = sorted(
        valid_prompts,
        key=lambda prompt: app_data["prompts"][prompt]["section_sequence"],
    )
    article["extraction"]["overview"] = sorted_prompts
    # BUG FIX: removed a leftover debug print of the prompt set, and
    # return the prompts in section order rather than arbitrary set order.
    return {p: app_data["prompts"][p] for p in sorted_prompts}
def select_performance_prompts(article,performance_assessment):
    """Collect the prompts whose terms validate for *performance_assessment*
    (clinical / radiologic / safety / other) and tag each prompt with the
    terms that selected it; also record the prompts on the article's
    extraction map.

    NOTE(review): ``search_text`` below is computed but never used —
    validate_term builds its own search text; this looks like dead code.
    NOTE(review): ``valid_prompts[p]`` aliases ``app_data["prompts"][p]``,
    so the "term" lists appended here accumulate on the shared prompt
    objects across calls — confirm this mutation is intended.
    NOTE(review): the extraction-set update at the bottom uses ``prompt``,
    which is the *last* prompt of the inner loop, so only one prompt per
    term is recorded — it looks like every prompt in t["instruction"]
    should be added; confirm before changing.
    """
    valid_terms = []
    search_text = article["key_content"]+article["Authors"]+article["Acceptance Month"]+article["Acceptance Year"]+"\n".join(article["tables"])
    search_text = search_text.lower()
    for t in app_data["terms"]:
        if validate_term(article,t,performance_assessment):
            # add the prompts to the memory
            valid_terms.append(t)
    # print("valid performance terms",valid_terms)
    valid_prompts = {}
    for t in valid_terms:
        if any([p not in valid_prompts for p in t["instruction"]]):
            for p in t["instruction"]:
                prompt = app_data["prompts"][p]
                valid_prompts[p] = prompt
                if "term" not in valid_prompts[p]:
                    valid_prompts[p]["term"] = [t]
                else:
                    valid_prompts[p]["term"].append(t)
            if performance_assessment not in article["extraction"]:
                article["extraction"][performance_assessment] = set()
            article["extraction"][performance_assessment].add(prompt["name"])
    # print("valid performance prompts: ",valid_prompts)
    return valid_prompts
def update_logic(article):
    """Derive the pathway-logic flags from the article's key content and
    prune the chain ids that the flags disable."""
    lowered = article["key_content"].lower()
    flags = {
        "group": lowered.count("group") >= 3,
        "preoperative": lowered.count("preoperative") >= 2,
        "chain id": list(range(6)),
    }
    article["logic"] = flags
    # chain 1 needs group comparisons, chain 3 needs preoperative data
    if not flags["group"]:
        flags["chain id"].remove(1)
    if not flags["preoperative"]:
        flags["chain id"].remove(3)
def _run_coroutine(coro):
    """Run *coro* to completion on a fresh event loop, always closing the
    loop afterwards (even when the coroutine raises)."""
    loop = asyncio.new_event_loop()
    try:
        return loop.run_until_complete(coro)
    finally:
        loop.close()
def process_prompts(article):
    """Identify the prompts applicable to *article* and execute them to
    extract content from the article.

    Prompts are selected per assessment based on the configured terms and
    the article's attributes; results are written onto the article by the
    concurrent executors.

    Parameters
    ----------
    article : dict
        Article object; its "extraction" map is rebuilt in place.

    Returns
    -------
    None
    """
    article["extraction"] = {}
    overview_prompts = select_overview_prompts(article)
    performance_assessments = ["clinical", "radiologic", "safety", "other"]
    performance_prompts = {
        assessment: select_performance_prompts(article, assessment)
        for assessment in performance_assessments
    }
    # BUG FIX: the per-assessment event loops were never closed when
    # execute_concurrent raised; _run_coroutine closes them in a finally.
    _run_coroutine(execute_concurrent(article, overview_prompts))
    for assessment in performance_assessments:
        _run_coroutine(execute_concurrent(article, performance_prompts[assessment]))
def validate_term(article, term, assessment):
    """Decide whether *term* applies to *article* for *assessment*.

    A term must match the article's anatomic region ("all" or the article
    domain).  Overview terms validate directly; performance terms require
    every comma-separated indication keyword to appear in the article's
    key text.

    Parameters
    ----------
    article : dict
        Article with key_content, Authors, acceptance fields and tables.
    term : dict
        Term row with region, assessment and indication_terms fields.
    assessment : str
        "overview" or a performance assessment name.

    Returns
    -------
    bool
        True when the term is applicable.
    """
    # region gate: "all" matches everything, otherwise must equal domain
    if term["region"].lower() != "all" and term["region"].lower() != article["domain"].lower():
        return False
    # overview terms validate without keyword matching
    if assessment == "overview" and term["assessment"] == "overview":
        return True
    if term["assessment"] == assessment:
        key_text = (article["key_content"] + article["Authors"] + article["Acceptance Month"]
                    + article["Acceptance Year"] + "\n".join(article["tables"]))
        # BUG FIX: the old code replaced the literal two characters "/n"
        # instead of the newline "\n", so keywords spanning line breaks
        # never matched.
        key_text = key_text.replace("\n", " ")
        key_text = key_text.lower()
        keywords = [kw.strip().lower() for kw in term["indication_terms"].split(",")]
        return all(kw in key_text for kw in keywords)
    return False
def _keyword_present(k, full_text):
    """True if keyword *k* occurs in *full_text*; a tuple/list/set of
    alternatives matches when any member (possibly nested) matches."""
    if isinstance(k, (tuple, list, set)):
        return any(_keyword_present(kw, full_text) for kw in k)
    return k in full_text
def keyword_search(keywords, full_text):
    """Report which keywords occur in *full_text*.

    Parameters
    ----------
    keywords : iterable
        Each element is either a plain keyword or a tuple/list/set of
        alternative keywords (any one matching counts as a hit).
    full_text : str
        Text searched for the keywords.

    Returns
    -------
    dict
        Keyword (or alternative group) -> bool presence flag.
    """
    # BUG FIX: the old code evaluated any() over the dict returned by the
    # recursive call — a non-empty dict is always truthy, so every
    # alternative group reported True regardless of the text; it also
    # recursed on plain strings, iterating their characters.
    return {k: _keyword_present(k, full_text) for k in keywords}
def post_process(article):
    """Aggregate the extracted performance segments per assessment and run
    the post-processing aggregation prompt over each aggregate.

    Parameters
    ----------
    article : dict
        Article whose "extraction" map lists the segments per assessment.

    Returns
    -------
    None
    """
    post_inputs = {}
    for assessment, segments in article["extraction"].items():
        # the overview is not aggregated into a follow-up table
        if assessment == "overview":
            continue
        post_inputs[assessment] = "\n".join([article[s] for s in segments])
    template = ChatPromptTemplate.from_messages([
        ("human", "{text}"),
        ("system", "From the text above {instruction}"),
    ])
    chain = template | llm
    # BUG FIX: the event loop was never closed; release it even on error.
    post_loop = asyncio.new_event_loop()
    try:
        post_loop.run_until_complete(run_post(article, post_inputs, chain))
    finally:
        post_loop.close()
async def run_post(article, post_inputs, chain):
    """Run the aggregation prompt for every assessment concurrently and
    record each generated segment name on the article's extraction map."""
    tasks = []
    for assessment, post_input in post_inputs.items():
        agg = app_data["prompts_agg"][assessment]
        variables = {"text": post_input, "instruction": " ".join(agg["chain"])}
        article["extraction"][assessment].add(agg["name"])
        tasks.append(async_generate(article, agg["name"], chain, input_variables=variables))
    await asyncio.gather(*tasks)