|
|
|
|
|
"""
@author: adnan-sadi
"""
|
|
|
|
|
import functools
import os

import gradio as gr
import numpy as np
import torch
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import set_seed
|
|
|
|
|
# Fix the random seed so repeated runs of the demo are reproducible.
seed = 17
set_seed(seed)

# Authenticate against the Hugging Face Hub with the access token read from
# the "hf_token" environment variable.
token = os.getenv("hf_token")
if token:
    login(token)
# When the variable is unset we deliberately skip login(): calling it with
# None falls back to an interactive prompt, which cannot be answered on a
# headless server. Model loading will then report any access problem itself.
|
|
|
|
|
|
|
|
|
|
|
# Diagnosis names in classifier-output order: position i in this list is the
# label reported for output index i.
labels = [
    'URTI',
    'HIV (initial infection)',
    'Pneumonia',
    'Chronic rhinosinusitis',
    'Viral pharyngitis',
    'Anemia',
    'Atrial fibrillation',
    'Allergic sinusitis',
    'Laryngospasm',
    'Cluster headache',
    'Anaphylaxis',
    'Spontaneous pneumothorax',
    'Acute pulmonary edema',
    'Tuberculosis',
    'Myasthenia gravis',
    'Panic attack',
    'Scombroid food poisoning',
    'Epiglottitis',
    'Inguinal hernia',
    'Boerhaave',
    'Pancreatic neoplasm',
    'Bronchitis',
    'SLE',
    'Acute laryngitis',
    'Unstable angina',
    'Bronchiectasis',
    'Possible NSTEMI / STEMI',
    'Chagas',
    'Localized edema',
    'Sarcoidosis',
    'Spontaneous rib fracture',
    'GERD',
    'Bronchospasm / acute asthma exacerbation',
    'Acute COPD exacerbation / infection',
    'Guillain-Barré syndrome',
    'Influenza',
    'Pulmonary embolism',
    'Stable angina',
    'Pericarditis',
    'Acute rhinosinusitis',
    'Whooping cough',
    'Myocarditis',
    'Acute dystonic reactions',
    'Pulmonary neoplasm',
    'Acute otitis media',
    'PSVT',
    'Croup',
    'Ebola',
    'Bronchiolitis',
]

# Lookup tables between a label name and its index, in both directions.
label2id = {name: position for position, name in enumerate(labels)}
id2label = dict(enumerate(labels))
|
|
|
|
|
|
|
|
|
|
|
# maxsize bounds how many checkpoints stay resident at once; it is a
# memory/latency trade-off and can be raised if the host has the RAM for it.
@functools.lru_cache(maxsize=2)
def _load_checkpoint(model_path):
    """Load the tokenizer/model pair for ``model_path``, memoizing the result."""
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    return tokenizer, model


def get_model_and_tokenizer(model_name):
    """Return the ``(tokenizer, model)`` pair for a model chosen in the UI.

    Args:
        model_name: One of the short names listed in ``model_dict`` below.
            The dropdown placeholder "Choose" maps to the same checkpoint as
            "roberta-sp-mtd".

    Returns:
        A ``(tokenizer, model)`` tuple loaded with ``AutoTokenizer`` and
        ``AutoModelForSequenceClassification``.

    Raises:
        KeyError: If ``model_name`` is not a key of ``model_dict``.

    Note:
        Loaded pairs are cached per checkpoint path, so repeated requests for
        the same model reuse one instance instead of reloading it on every
        call. The returned objects are therefore shared between calls and
        should be treated as read-only.
    """
    model_dict = {
        "bert-base": "AdnanSadi/Bert_DDXPlus_1",
        "distilbert-base": "AdnanSadi/DistilBert_DDXPlus_2",
        "roberta-base": "AdnanSadi/Roberta_DDXPlus_1",
        "bds-bert-base": "AdnanSadi/BioDisSumBert_DDXPlus_1",
        "bert-sp-mtd": "AdnanSadi/Bert_DDXPlus_2",
        "distilbert-sp-mtd": "AdnanSadi/DistilBert_DDXPlus_3",
        "roberta-sp-mtd": "AdnanSadi/Roberta_DDXPlus_2",
        "bds-bert-sp-mtd": "AdnanSadi/BioDisSumBert_DDXPlus_2",
        "Choose": "AdnanSadi/Roberta_DDXPlus_2",
    }

    model_path = model_dict[model_name]

    # Cache on the checkpoint path rather than the UI name so that aliases
    # ("Choose" and "roberta-sp-mtd") share a single cache entry.
    return _load_checkpoint(model_path)
|
|
|
|
|
|
|
|
def get_diagnosis(age, sex, medhist, symptoms, model_name, threshold):
    """Build a numbered differential diagnosis for one patient report.

    The patient details are assembled into a single prompt, run through the
    selected sequence-classification model, and every label whose sigmoid
    score is at least ``threshold`` is reported.

    Args:
        age: Patient age, inserted into the prompt as-is.
        sex: Patient sex, inserted into the prompt as-is.
        medhist: Medical-history text (the UI suggests a list, one item per
            line). ``None`` is treated as empty text.
        symptoms: Symptom text in the same form. ``None`` is treated as empty
            text.
        model_name: Short model name accepted by ``get_model_and_tokenizer``.
        threshold: Minimum sigmoid score for a label to be listed.

    Returns:
        A string with one line per selected label, formatted as
        ``"<n>. <label> (Conf. Score: <score>)\\n"`` in label-index order, or
        an empty string when no label reaches the threshold.
    """
    tokenizer, model = get_model_and_tokenizer(model_name)

    # The Clear button writes None into both textboxes; if that value reaches
    # this function it would otherwise be rendered as the literal word "None"
    # inside the prompt.
    medhist = "" if medhist is None else medhist
    symptoms = "" if symptoms is None else symptoms

    # Prompt wording and layout are kept exactly as before; presumably they
    # mirror the format the models were fine-tuned on — TODO confirm.
    text = (
        "The following is a list of medical history and symptoms described by a patient."
        f"\nSex: {sex}, Age: {age}"
        f"\nMedical History:\n{medhist}"
        f"\nSymptoms:\n{symptoms}"
    )

    encoding = tokenizer(text, truncation=True, return_tensors="pt")

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logits = model(**encoding).logits

    # A single text was encoded, so row 0 holds its per-label logits. Indexing
    # (instead of squeeze()) keeps the result 1-D even for a one-label model.
    probs = torch.sigmoid(logits[0]).cpu()

    # Indices of every label at or above the threshold, in ascending order.
    selected = torch.nonzero(probs >= threshold).flatten().tolist()

    lines = [
        f"{rank}. {id2label[idx]} (Conf. Score: {round(probs[idx].item(), 4)})\n"
        for rank, idx in enumerate(selected, start=1)
    ]
    return "".join(lines)
|
|
|
|
|
|
|
|
# Example patients shown under the form. Each entry is
# [sex, age, medical history, symptoms, model name], matching the order of the
# input components handed to gr.Examples.
sample_1 = [
    "Male",
    "90",
    "\n".join([
        "- I work in a daycare.",
        "- I smoke cigarettes.",
        "- I have had a cold in the last 2 weeks.",
        "- I have not traveled anywhere in the last 4 weeks.",
    ]),
    "\n".join([
        "- I feel pain.",
        "- The pain is sensitive.",
        "- I feel pain in my tonsil(R).",
        "- I feel pain in my thyroid cartilage.",
        "- I feel pain in my palate.",
        "- I feel pain in my pharynx.",
        "- I feel pain in my under the jaw.",
        "- On a scale of 0-10, the pain intensity is 4.",
        "- The pain does not radiate to anywhere.",
        "- On a scale of 0-10, the pain's location precision is 4.",
        "- On a scale of 0-10, the pace at which the pain appear is 2.",
        "- I have a cough.",
        "- I have noticed a change in the tone of my voice.",
    ]),
    "bds-bert-sp-mtd",
]

sample_2 = [
    "Female",
    "16",
    "\n".join([
        "- I feel anxious.",
        "- I regularly drink coffee or tea.",
        "- I consume energy drinks regularly.",
        "- I regularly take stimulant drugs.",
        "- I have not traveled anywhere in the last 4 weeks.",
    ]),
    "\n".join([
        "- I feel pain.",
        "- The pain is burning.",
        "- I feel pain in my back of head.",
        "- I feel pain in my forehead.",
        "- I feel pain in my temple(R).",
        "- On a scale of 0-10, the pain intensity is 6.",
        "- The pain does not radiate to anywhere.",
        "- On a scale of 0-10, the pain's location precision is 2.",
        "- On a scale of 0-10, the pace at which the pain appear is 6.",
        "- I am experiencing shortness of breath or difficulty breathing in a significant way.",
        "- I feel lightheaded, dizzy, and about to faint.",
        "- I feel palpitations.",
    ]),
    "bert-base",
]

sample_3 = [
    "Male",
    "57",
    "\n".join([
        "- Some members of my family have been diagnosed with myasthenia gravis.",
        "- I have not traveled anywhere in the last 4 weeks.",
    ]),
    "\n".join([
        "- I have pain or weakness in my jaw.",
        "- I have difficulty articulating words/speaking.",
        "- I have a feeling of discomfort/blockage when swallowing.",
        "- I am experiencing shortness of breath or difficulty breathing in a significant way.",
        "- I feel weakness in both arms and/or both legs.",
    ]),
    "roberta-sp-mtd",
]
|
|
|
|
|
|
|
|
|
|
|
# Choices offered in the model dropdown. "Choose" is the placeholder entry.
model_names = [
    "Choose",
    "bert-base",
    "distilbert-base",
    "roberta-base",
    "bds-bert-base",
    "bert-sp-mtd",
    "distilbert-sp-mtd",
    "roberta-sp-mtd",
    "bds-bert-sp-mtd",
]

# Initial position of the confidence slider. The Clear button restores this
# same value; previously it reset the slider to 0.2, which did not match the
# state the form starts in.
DEFAULT_THRESHOLD = 0.5

# Introductory text rendered at the top of the page.
INTRO_MARKDOWN = """
# Differential Diagnosis Tool

This demo contains the models described in paper [Automatic Differential Diagnosis using Transformer-Based Multi-Label Sequence Classification](https://doi.org/10.48550/arXiv.2408.15827).
The models were trained to provide a differential diagnosis based on the medical history and symptoms described by a patient.
Please fill out the following form with relevant information, including the age, sex, medical history, and symptoms of the patient. For best results, please provide the symptoms and medical history information as a list.
#### For reference, please look over some of the examples provided at the bottom.
### Acknowledgments: This project was funded by North South University CTRG.
"""

demo = gr.Blocks()

# NOTE(review): the nesting of the layout containers below is reconstructed;
# confirm that the output box is meant to sit beside the form column.
with demo:
    gr.Markdown(INTRO_MARKDOWN)

    with gr.Row():
        # Input form.
        with gr.Column():
            with gr.Row():
                age = gr.components.Number(label="Age", interactive=True)
                gender = gr.components.Dropdown(
                    ["Choose", "Male", "Female"],
                    label="Gender",
                    value="Choose",
                    interactive=True,
                )

            medhist = gr.components.Textbox(
                label="Medical History",
                info="a list of patient's medical history",
                lines=5,
                interactive=True,
            )
            symptoms = gr.components.Textbox(
                label="Symptoms",
                info="a list of patient's symptoms",
                lines=5,
                interactive=True,
            )
            model_name = gr.components.Dropdown(
                model_names,
                label="Model",
                info="Defaults to Roberta",
                value="Choose",
                interactive=True,
            )

            threshold_value = gr.components.Slider(
                0,
                1,
                value=DEFAULT_THRESHOLD,
                label="Model Confidence Threshold",
                info="Choose between 0 and 1",
                interactive=True,
            )

            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", variant="primary")

        # Read-only box that receives the text returned by get_diagnosis.
        output_box = gr.Textbox(
            label="Differential Diagnosis Based on Patient Report:",
            lines=5,
            interactive=False,
        )

    gr.Markdown("## Patient Examples")
    # Each sample is ordered [sex, age, medical history, symptoms, model],
    # hence gender comes before age in the component list.
    gr.Examples(
        [sample_1, sample_2, sample_3],
        [gender, age, medhist, symptoms, model_name],
    )

    # Clear: put every input back to its reset value and empty the output.
    clear_btn.click(
        lambda: [0, "Choose", None, None, "Choose", DEFAULT_THRESHOLD, None],
        outputs=[age, gender, medhist, symptoms, model_name, threshold_value, output_box],
    )
    # Submit: run the selected model on the form contents.
    submit_btn.click(
        fn=get_diagnosis,
        inputs=[age, gender, medhist, symptoms, model_name, threshold_value],
        outputs=output_box,
    )

demo.launch(share=False, debug=False)