"""Gradio demo: ADR severity classification, entity extraction and SHAP plots.

The module loads a sequence-classification model and a biomedical NER
pipeline at import time, then exposes ``classify_adr`` through a Gradio
interface.
"""

import os
import re
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import shap
import torch
from shap.explainers import Permutation
from shap.maskers import Text
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    pipeline,
)

# Figures are only ever written to disk, never shown, so use a
# non-interactive backend that does not need a display.
plt.switch_backend("Agg")

# Device configuration
device = torch.device("cpu")
print(f"✅ Running on device: {device}")

# Load model and tokenizer
model = (
    AutoModelForSequenceClassification.from_pretrained(
        "calerio-uva/roberta-adr-model"
    )
    .to(device)
    .eval()
)
tokenizer = AutoTokenizer.from_pretrained("calerio-uva/roberta-adr-model")

# NER pipeline
ner = pipeline(
    "ner",
    model="d4data/biomedical-ner-all",
    tokenizer="d4data/biomedical-ner-all",
    aggregation_strategy="simple",
    device=-1,
)

# SHAP setup
clf_pipeline = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    top_k=None,
    device=-1,
)


def shap_predict(texts):
    """Return class scores for *texts* as an array ordered by label id.

    Args:
        texts: Iterable of inputs; each item is converted with ``str``.

    Returns:
        ``np.ndarray`` of shape ``(len(texts), num_labels)`` where column
        ``i`` holds the score of the label whose id is ``i``.
    """
    texts = [str(t) for t in texts]
    results = clf_pipeline(texts, truncation=True, padding=True, max_length=512)

    # The pipeline reports labels via ``id2label``; invert it so each score
    # can be written to its own column. With ``top_k=None`` the pipeline
    # returns entries sorted by score, so positional order is NOT label order.
    label_to_index = {
        label: int(idx) for idx, label in model.config.id2label.items()
    }
    num_labels = len(label_to_index)

    scores = np.zeros((len(texts), num_labels), dtype=float)
    for row, result in zip(scores, results):
        entries = [result] if isinstance(result, dict) else result
        for entry in entries:
            row[label_to_index[entry["label"]]] = entry["score"]
        if len(entries) == 1 and num_labels == 2:
            # Only the top label was reported: the other class gets the
            # complementary probability.
            reported = label_to_index[entries[0]["label"]]
            row[1 - reported] = 1 - entries[0]["score"]
    return scores


masker = Text(tokenizer)
explainer = Permutation(
    shap_predict, masker, output_names=["Not Severe", "Severe"]
)

SYMPTOM_TAGS = {"sign_symptom", "symptom"}
DISEASE_TAGS = {"disease_disorder"}
MED_TAGS = {"medication", "administration", "therapeutic_procedure"}

# Maximum number of characters of user text that is analysed.
_MAX_CHARS = 512
# Entities scored below this threshold by the NER pipeline are dropped.
_MIN_ENTITY_SCORE = 0.6
# Matches a standalone "nan" token only, so words that merely contain the
# letters (e.g. "pregnant", "malignant") are left intact.
_NAN_TOKEN = re.compile(r"\bnan\b")
_SPACE_RUN = re.compile(r" {2,}")


def dedupe_and_filter(tokens):
    """Return stripped tokens of length >= 3, de-duplicated case-insensitively.

    The first spelling encountered is kept and input order is preserved.
    """
    seen, out = set(), []
    for tok in tokens:
        w = tok.strip()
        if len(w) < 3:
            continue
        lw = w.lower()
        if lw not in seen:
            seen.add(lw)
            out.append(w)
    return out


def _clean_text(text):
    """Normalise raw user input before it is passed to the models."""
    raw = "" if text is None else str(text)
    without_nan = _NAN_TOKEN.sub("", raw)
    # Removing a token can leave doubled spaces behind; collapse them.
    collapsed = _SPACE_RUN.sub(" ", without_nan).strip()
    return collapsed[:_MAX_CHARS]


def _merge_entity_spans(ents):
    """Merge consecutive same-group entities that touch or are 1 char apart."""
    spans = []
    for ent in ents:
        grp = ent["entity_group"].lower()
        start, end = ent["start"], ent["end"]
        score = ent.get("score", 1.0)
        prev = spans[-1] if spans else None
        if prev and prev["group"] == grp and start <= prev["end"] + 1:
            prev["end"] = max(prev["end"], end)
            prev["score"] = max(prev["score"], score)
        else:
            spans.append(
                {"group": grp, "start": start, "end": end, "score": score}
            )
    return spans


def _extract_entities(clean):
    """Run NER on *clean* and return ``(symptoms, diseases, medications)``."""
    spans = _merge_entity_spans(ner(clean))

    # Extend medication-like spans to the end of the current alphabetic run.
    for span in spans:
        if span["group"] in MED_TAGS:
            end = span["end"]
            while end < len(clean) and clean[end].isalpha():
                end += 1
            span["end"] = end

    tokens = []
    for span in spans:
        if span["score"] < _MIN_ENTITY_SCORE:
            continue
        chunk = clean[span["start"]:span["end"]].strip()
        if len(chunk) >= 3:
            tokens.append((chunk, span["group"]))

    symptoms = dedupe_and_filter([t for t, g in tokens if g in SYMPTOM_TAGS])
    diseases = dedupe_and_filter([t for t, g in tokens if g in DISEASE_TAGS])
    medications = dedupe_and_filter([t for t, g in tokens if g in MED_TAGS])
    return symptoms, diseases, medications


def _render_shap_image(clean, n_tokens):
    """Render a SHAP bar plot for *clean*; return the PNG path or ``None``.

    Explanation is best-effort: any failure is printed and ``None`` is
    returned so the rest of the prediction is still shown.
    """
    word_budget = min(400, len(clean.split()) * 5)
    # The Permutation explainer rejects max_evals below 2 * num_features + 1.
    # NOTE(review): assumes the Text masker's feature count matches the
    # tokenizer's token count for this input — confirm against shap.
    max_evals = max(word_budget, 2 * n_tokens + 1)

    fig = None
    try:
        shap_values = explainer([clean], max_evals=max_evals)
        fig = plt.figure()
        # Per-token attributions for output index 1 ("Severe").
        shap.plots.bar(shap_values[0, :, 1], show=False)
        # Unique file per request so concurrent users do not overwrite each
        # other's image; the file is closed before matplotlib writes to it.
        with tempfile.NamedTemporaryFile(
            prefix="shap_expl_", suffix=".png", delete=False
        ) as handle:
            shap_path = handle.name
        plt.savefig(shap_path, bbox_inches="tight")
        return shap_path
    except Exception as e:  # best-effort feature: report and carry on
        print(f"[SHAP Error] {e}")
        return None
    finally:
        if fig is not None:
            plt.close(fig)


def classify_adr(text, explain=False):
    """Classify an ADR description and extract its entities.

    Args:
        text: Free-text ADR description; ``None`` is treated as empty.
        explain: When true, also compute a SHAP bar plot (slow).

    Returns:
        A 6-tuple ``(probabilities, symptoms, diseases, medications,
        comment, shap_path)``. The first five items are strings;
        ``shap_path`` is the path of a PNG file, or ``None`` when no
        explanation was requested or it could not be produced.
    """
    clean = _clean_text(text)

    # Predict
    inputs = tokenizer(
        clean,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    ).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=1)[0].cpu().numpy()

    # NER
    symptoms, diseases, medications = _extract_entities(clean)

    if probs[1] > 0.9:
        comment = "❗ High confidence this is a severe ADR."
    elif probs[1] > 0.5:
        comment = "⚠️ Borderline case — may be severe."
    else:
        comment = "✅ Likely not severe."

    # SHAP explanation as image
    shap_path = None
    if explain:
        shap_path = _render_shap_image(clean, inputs["input_ids"].shape[1])

    return (
        f"Not Severe (0): {probs[0]:.3f}\nSevere (1): {probs[1]:.3f}",
        "\n".join(symptoms) or "None detected",
        "\n".join(diseases) or "None detected",
        "\n".join(medications) or "None detected",
        comment,
        shap_path,
    )


demo = gr.Interface(
    fn=classify_adr,
    inputs=[
        gr.Textbox(lines=5, label="ADR Description"),
        gr.Checkbox(label="Generate SHAP Explanation (VERY slow)", value=False),
    ],
    outputs=[
        gr.Textbox(label="Predicted Probabilities"),
        gr.Textbox(label="Symptoms"),
        gr.Textbox(label="Diseases or Conditions"),
        gr.Textbox(label="Medications"),
        gr.Textbox(label="Interpretation"),
        gr.Image(label="SHAP Explanation"),
    ],
    title="ADR Severity & NER Classifier 2",
    description="Paste an ADR description to classify severity, extract symptoms, diseases, medications, and visualize SHAP explanations.",
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", 7860)),
    )