# Source captured from a Hugging Face Spaces page (Space status at capture time: Sleeping).
import csv
import html
import sys

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import shap
import streamlit as st
import torch
import transformers
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification
# Lift the csv module's per-field size limit as high as the platform allows.
# NOTE(review): presumably needed because Gradio's example cache
# (cache_examples=True below) stores large outputs such as the SHAP HTML in a
# CSV file — confirm.
# sys.maxsize does not fit in a C long where long is 32-bit (e.g. Windows), and
# field_size_limit raises OverflowError there; fall back to the largest 32-bit
# signed value in that case.
try:
    csv.field_size_limit(sys.maxsize)
except OverflowError:
    csv.field_size_limit(2**31 - 1)
# Pick the first CUDA device when one is available, otherwise the CPU.
# NOTE(review): `device` is never passed to the models or pipelines below, so
# inference appears to run on CPU regardless — confirm whether that is intended.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# ADR severity classifier; adr_predict reads logit 1 as "Severe Reaction" and
# logit 0 as "Non-severe Reaction".
tokenizer = AutoTokenizer.from_pretrained("jschwaller/ADRv2024")
model = AutoModelForSequenceClassification.from_pretrained("jschwaller/ADRv2024")

# Pipeline wrapper that returns the score of every label (top_k=None); this is
# the callable handed to the SHAP explainer.
pred = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    top_k=None,
)
explainer = shap.Explainer(pred)

# Biomedical named-entity recogniser. aggregation_strategy="simple" groups
# sub-word tokens into entity spans carrying 'entity_group'/'start'/'end'.
ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
ner_pipe = pipeline(
    "ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    aggregation_strategy="simple",
)  # pass device=0 if using gpu
# Highlight colour for each NER entity group shown in the app. Insertion order
# is also the order in which the legend lists the groups.
entity_colors = dict(
    Severity='#E63946',              # a vivid red
    Sign_symptom='#2A9D8F',          # a deep teal
    Medication='#457B9D',            # a dusky blue
    Age='#F4A261',                   # a sandy orange
    Sex='#F4A261',                   # same sandy orange for consistency with 'Age'
    Diagnostic_procedure='#9C6644',  # a brown
    Biological_structure='#BDB2FF',  # a light pastel purple
)
def _highlight_entities(text, entities, colors):
    """Return *text* as HTML with each coloured entity wrapped in a ``<mark>``.

    Args:
        text: The string the entities were extracted from.
        entities: NER pipeline output; each item needs 'entity_group',
            'start' and 'end' (character offsets into *text*).
        colors: Mapping of entity group -> CSS colour.

    Returns:
        An HTML string. All text is HTML-escaped; entity groups missing from
        *colors*, and spans overlapping an already highlighted span, are left
        as plain text.
    """
    pieces = []
    cursor = 0
    for ent in entities:
        color = colors.get(ent["entity_group"])
        start, end = ent["start"], ent["end"]
        # Groups without a legend colour used to raise KeyError; leave them
        # unhighlighted instead. Overlaps would duplicate text, so skip them.
        if color is None or start < cursor:
            continue
        pieces.append(html.escape(text[cursor:start]))
        # Slice the source text rather than using the pipeline's reconstructed
        # 'word', so the highlight shows exactly what was typed.
        pieces.append(
            f"<mark style='background-color:{color};'>"
            f"{html.escape(text[start:end])}</mark>"
        )
        cursor = end
    pieces.append(html.escape(text[cursor:]))
    return "".join(pieces)


def adr_predict(x):
    """Score *x* for ADR severity and build the SHAP and NER HTML views.

    Args:
        x: Input text (``main`` lowercases it before calling).

    Returns:
        A ``(label_scores, local_plot, htext)`` tuple:
        ``label_scores`` maps "Severe Reaction" (logit 1) and
        "Non-severe Reaction" (logit 0) to softmax probabilities,
        ``local_plot`` is the SHAP text plot as HTML, and ``htext`` is *x*
        as HTML with recognised entities highlighted.
    """
    encoded_input = tokenizer(x, return_tensors='pt')
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        logits = model(**encoded_input)[0][0]
    # softmax without an explicit dim is deprecated; the logits are 1-D here.
    probs = torch.nn.functional.softmax(logits, dim=-1).numpy()

    shap_values = explainer([str(x).lower()])
    local_plot = shap.plots.text(shap_values[0], display=False)

    htext = _highlight_entities(x, ner_pipe(x), entity_colors)
    return (
        {"Severe Reaction": float(probs[1]), "Non-severe Reaction": float(probs[0])},
        local_plot,
        htext,
    )
def main(prob1):
    """Gradio callback: lowercase the textbox value and run ``adr_predict``.

    Args:
        prob1: Raw value of the input textbox.

    Returns:
        The three values produced by ``adr_predict`` (label scores, SHAP HTML,
        NER HTML), in the order of the callback's output components.
    """
    label_scores, shap_html, ner_html = adr_predict(str(prob1).lower())
    return label_scores, shap_html, ner_html
# HTML legend: one coloured badge per NER entity group, in entity_colors order.
legend_html = """
<div style='margin-top: 20px; color: white;'> <!-- Ensure the legend text is white for visibility -->
<h3>NER Legend</h3>
<ul style='list-style-type:none; padding-left: 0;'> <!-- Remove padding from the list -->
""" + "".join(
    "<li><span style='color: white; background-color: "
    f"{color}; padding: 5px 10px; margin-right: 5px; border-radius: 5px;'>"
    f"{entity}</span></li>"
    for entity, color in entity_colors.items()
) + "</ul></div>"

# Gradio component holding the legend. It is created before the gr.Blocks
# context is opened, so it is not yet part of any layout.
ner_legend = gr.HTML(value=legend_html)
# Page heading (rendered as Markdown) and the description shown beneath it.
title = "Welcome to **ADR Tracker**"
description1 = (
    "This app takes text (up to a few sentences) and predicts to what extent "
    "the text describes a severe (or non-severe) adverse reaction to medications. "
    "Please do NOT use for medical diagnosis."
)

# Custom stylesheet passed to gr.Blocks. #163E64 is a dark navy; the previous
# inline CSS comment wrongly called it "light blue".
# NOTE(review): no component in this file sets `elem_classes`, so confirm the
# `.textbox` / `.button` rules actually match Gradio's generated markup.
css = """
body { font-family: 'Roboto', sans-serif; background-color: #333; color: #163E64; }
h1, h2, h3, h4, h5, h6, p, label, .markdown { color: #163E64; } /* Keep all text elements a consistent dark navy */
.textbox { width: 100%; border-radius: 10px; border: 1px solid #ccc; background-color: white; color: black; }
.button { background-color: #FF6347; color: white; border: none; border-radius: 10px; padding: 10px 20px; cursor: pointer; }
"""
# Assemble the UI: input box and Analyze button, the three outputs wired to
# `main`, the NER legend, and clickable examples.
with gr.Blocks(css=css) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("---")
    prob1 = gr.Textbox(label="Enter Your Text Here:", lines=2, placeholder="Type it here...")
    submit_btn = gr.Button("Analyze")
    with gr.Row():
        with gr.Column(visible=True):
            label = gr.Label(label="Predicted Label")
        with gr.Column(visible=True):
            local_plot = gr.HTML(label='Shap:')
            htext = gr.HTML(label="NER")
    submit_btn.click(
        main,
        [prob1],
        [label, local_plot, htext],
        api_name="adr",
    )
    # Display the NER legend below the buttons. `ner_legend` was created before
    # this context was opened; a bare reference does not add it to the page,
    # .render() places it here.
    ner_legend.render()
    with gr.Row():
        gr.Markdown("### Click on any of the examples below to see how it works:")
    gr.Examples(
        [["A 35 year-old female had suicidal ideation after taking Prednisone."],
         ["A 23 year-old male had minor nausea after taking Acetaminophen."]],
        [prob1],
        [label, local_plot, htext],
        main,
        # Gradio precomputes `main` on these examples so clicking one is fast.
        cache_examples=True,
    )
demo.launch()