# Author: calerio
# Last commit: 1f15c83 "Update app.py" (verified)
import os
import re
import tempfile

import torch
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
import shap
from shap.maskers import Text
from shap.explainers import Permutation
# Device configuration: the classifier is placed on the CPU (the pipelines
# below likewise pass device=-1).
device = torch.device("cpu")
# The status emoji was previously stored as mis-encoded bytes ("βœ…").
print(f"✅ Running on device: {device}")
# Severity classifier and its tokenizer, both loaded from the same Hub
# checkpoint. The model is moved to `device` and put in eval (inference) mode.
_ADR_MODEL_ID = "calerio-uva/roberta-adr-model"
model = AutoModelForSequenceClassification.from_pretrained(_ADR_MODEL_ID)
model = model.to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(_ADR_MODEL_ID)
# Biomedical named-entity recognition. With aggregation_strategy="simple" each
# result carries "entity_group", "start", "end" and "score", which is what
# classify_adr reads.
_NER_MODEL_ID = "d4data/biomedical-ner-all"
ner = pipeline(
    task="ner",
    model=_NER_MODEL_ID,
    tokenizer=_NER_MODEL_ID,
    device=-1,
    aggregation_strategy="simple",
)
# Text-classification pipeline wrapping the already-loaded classifier; it is
# the prediction function SHAP calls (via shap_predict). top_k=None requests a
# score for every label instead of only the best one.
clf_pipeline = pipeline(
    task="text-classification",
    model=model,
    tokenizer=tokenizer,
    device=-1,
    top_k=None,
)
def shap_predict(texts):
    """Score a batch of texts for SHAP, returning one probability column per label.

    Args:
        texts: Iterable of texts (SHAP passes masked variants of the input);
            every element is coerced with ``str``.

    Returns:
        ``np.ndarray`` with one row per text. Columns follow the model's label
        ids (column 0 = label id 0, column 1 = label id 1). This is the order
        the explainer's ``output_names`` ["Not Severe", "Severe"] assumes.
    """
    texts = [str(t) for t in texts]
    results = clf_pipeline(texts, truncation=True, padding=True, max_length=512)
    # Map each label string the pipeline emits back to its class index.
    label_index = {label: idx for idx, label in clf_pipeline.model.config.id2label.items()}
    scores = []
    for result in results:
        if isinstance(result, dict):
            # Only the top label came back: rebuild the binary distribution,
            # placing its score in the column of the label it names.
            p = result["score"]
            if label_index.get(result["label"], 1) == 1:
                scores.append([1 - p, p])
            else:
                scores.append([p, 1 - p])
        else:
            # With top_k=None the pipeline sorts labels by descending score.
            # Reading them in list order would swap the columns whenever
            # "Severe" outscores "Not Severe", so re-sort by class index.
            ordered = sorted(result, key=lambda entry: label_index[entry["label"]])
            scores.append([entry["score"] for entry in ordered])
    return np.array(scores)
# SHAP Permutation explainer over shap_predict. The Text masker hides tokens
# produced by the classifier's own tokenizer; output_names is intended to label
# the two probability columns shap_predict returns.
masker = Text(tokenizer)
explainer = Permutation(
    shap_predict,
    masker,
    output_names=["Not Severe", "Severe"],
)
# Lower-cased NER entity groups, routed to the three entity columns of the UI.
SYMPTOM_TAGS = {
    "symptom",
    "sign_symptom",
}
DISEASE_TAGS = {
    "disease_disorder",
}
MED_TAGS = {
    "therapeutic_procedure",
    "administration",
    "medication",
}
def dedupe_and_filter(tokens):
    """Return the stripped tokens with short and duplicate entries removed.

    A token is dropped when, after stripping surrounding whitespace, it is
    shorter than three characters, or when its lower-cased form matches a
    token that was already kept. Survivors keep their first-seen spelling
    and their original order.
    """
    kept = []
    kept_keys = set()
    for raw in tokens:
        word = raw.strip()
        key = word.lower()
        if len(word) >= 3 and key not in kept_keys:
            kept_keys.add(key)
            kept.append(word)
    return kept
def _clean_text(text):
    """Normalise raw input: drop stand-alone "nan" tokens, collapse whitespace, cap at 512 chars."""
    # Word boundaries matter: a plain substring replace of "nan" mangled
    # medical words such as "pregnancy" and "malignant". The stand-alone
    # "nan" is presumably missing-value residue in the input data.
    text = re.sub(r"\bnan\b", " ", text or "")
    # split()/join strips the ends and collapses any whitespace run to one space.
    return " ".join(text.split())[:512]


def _predict_probs(clean):
    """Return the classifier's softmax probabilities for *clean* as a 1-D array."""
    inputs = tokenizer(clean, return_tensors="pt", truncation=True, padding=True, max_length=512).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    return torch.softmax(logits, dim=1)[0].cpu().numpy()


def _extract_entities(clean):
    """Run NER on *clean* and return (symptoms, diseases, medications) lists."""
    spans = []
    for ent in ner(clean):
        group = ent["entity_group"].lower()
        start, end = ent["start"], ent["end"]
        score = ent.get("score", 1.0)
        prev = spans[-1] if spans else None
        # Merge with the previous span when it has the same group and starts at
        # most one character after it ends; keep the best score.
        if prev is not None and prev["group"] == group and start <= prev["end"] + 1:
            prev["end"] = max(prev["end"], end)
            prev["score"] = max(prev["score"], score)
        else:
            spans.append({"group": group, "start": start, "end": end, "score": score})

    for span in spans:
        if span["group"] in MED_TAGS:
            # Extend medication spans through any trailing letters so a
            # partially tagged drug name is captured whole.
            stop = span["end"]
            while stop < len(clean) and clean[stop].isalpha():
                stop += 1
            span["end"] = stop

    tokens = []
    for span in spans:
        if span["score"] < 0.6:
            continue  # discard low-confidence entities
        chunk = clean[span["start"]:span["end"]].strip()
        if len(chunk) >= 3:
            tokens.append((chunk, span["group"]))

    symptoms = dedupe_and_filter([t for t, g in tokens if g in SYMPTOM_TAGS])
    diseases = dedupe_and_filter([t for t, g in tokens if g in DISEASE_TAGS])
    medications = dedupe_and_filter([t for t, g in tokens if g in MED_TAGS])
    return symptoms, diseases, medications


def _interpret(severe_prob):
    """Map the "Severe" probability to a short human-readable comment."""
    if severe_prob > 0.9:
        return "❗ High confidence this is a severe ADR."
    if severe_prob > 0.5:
        return "⚠️ Borderline case — may be severe."
    return "✅ Likely not severe."


def _shap_plot(clean):
    """Render a SHAP bar plot for the "Severe" output; return its PNG path, or None on failure."""
    try:
        n_features = len(tokenizer(clean)["input_ids"])
        # shap's Permutation explainer raises when max_evals < 2 * num_features + 1.
        # The tokenizer's id count (special tokens included) is an upper bound
        # on num_features, so enforce that floor under the original budget.
        max_evals = max(min(400, len(clean.split()) * 5), 2 * n_features + 1)
        shap_values = explainer([clean], max_evals=max_evals)
        plt.figure()
        try:
            # Sample 0, all tokens, output column 1 ("Severe"): one value per
            # token, instead of the whole tokens x outputs matrix.
            shap.plots.bar(shap_values[0, :, 1], show=False)
            # A unique file per call, so a concurrent request cannot overwrite
            # the image before Gradio reads it.
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as handle:
                shap_path = handle.name
            plt.savefig(shap_path, bbox_inches="tight")
        finally:
            # Always release the figure, so failed plots do not leak memory.
            plt.close()
        return shap_path
    except Exception as e:
        # Best effort: the explanation is optional, so log it and show no image.
        print(f"[SHAP Error] {e}")
        return None


def classify_adr(text, explain=False):
    """Classify ADR severity, extract biomedical entities, and optionally explain.

    Args:
        text: Free-text ADR description (``None`` is treated as empty). Only
            the first 512 characters, after cleaning, are analysed.
        explain: When true, also compute a SHAP bar plot (slow).

    Returns:
        A 6-tuple matching the Gradio outputs:
        - the probability summary string;
        - the newline-joined symptoms, diseases and medications, each
          "None detected" when empty;
        - an interpretation comment;
        - the path to the SHAP plot PNG, or ``None``.
    """
    clean = _clean_text(text)
    probs = _predict_probs(clean)
    symptoms, diseases, medications = _extract_entities(clean)
    comment = _interpret(probs[1])
    # Nothing to explain for empty input; skip the expensive SHAP run.
    shap_path = _shap_plot(clean) if explain and clean else None
    return (
        f"Not Severe (0): {probs[0]:.3f}\nSevere (1): {probs[1]:.3f}",
        "\n".join(symptoms) or "None detected",
        "\n".join(diseases) or "None detected",
        "\n".join(medications) or "None detected",
        comment,
        shap_path,
    )
# Gradio UI. The components are built in the same order as before (inputs,
# then outputs); outputs line up one-to-one with classify_adr's 6-tuple.
_adr_inputs = [
    gr.Textbox(lines=5, label="ADR Description"),
    gr.Checkbox(label="Generate SHAP Explanation (VERY slow)", value=False),
]
_adr_outputs = [
    gr.Textbox(label="Predicted Probabilities"),
    gr.Textbox(label="Symptoms"),
    gr.Textbox(label="Diseases or Conditions"),
    gr.Textbox(label="Medications"),
    gr.Textbox(label="Interpretation"),
    gr.Image(label="SHAP Explanation"),
]
demo = gr.Interface(
    fn=classify_adr,
    inputs=_adr_inputs,
    outputs=_adr_outputs,
    title="ADR Severity & NER Classifier 2",
    description="Paste an ADR description to classify severity, extract symptoms, diseases, medications, and visualize SHAP explanations.",
    allow_flagging="never",
)
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from $PORT when set, else 7860.
    port = int(os.environ.get("PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=port)