import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
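
# Sequence classifier and tokenizer for ADR severity (severe vs. not severe).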
model = AutoModelForSequenceClassification.from_pretrained("calerio-uva/roberta-adr-model")
tokenizer = AutoTokenizer.from_pretrained("calerio-uva/roberta-adr-model")
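
# Biomedical NER pipeline; aggregation_strategy="simple" merges word-piece tokens
# back into whole-entity spans with character offsets and confidence scores.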
ner = pipeline(
    "ner",
    model="d4data/biomedical-ner-all",
    tokenizer="d4data/biomedical-ner-all",
    aggregation_strategy="simple",
)
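
# Lowercased entity groups from the NER model, mapped onto the three output buckets.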
SYMPTOM_TAGS = {"sign_symptom", "symptom"}
DISEASE_TAGS = {"disease_disorder"}
MED_TAGS = {"medication", "administration", "therapeutic_procedure"}


def dedupe_and_filter(tokens):
    """Drop fragments shorter than three characters and case-insensitive
    duplicates, preserving first-seen order and original casing."""
    seen, out = set(), []
    for tok in tokens:
        w = tok.strip()
        if len(w) < 3:
            continue
        lw = w.lower()
        if lw not in seen:
            seen.add(lw)
            out.append(w)
    return out


def classify_adr(text: str):
    """Classify ADR severity and extract symptom, disease, and medication entities."""
    print("[DEBUG] Running classify_adr", flush=True)
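
    # Light cleanup: drop literal "nan" placeholders and collapse double spaces.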
    clean = text.strip().replace("nan", "").replace("  ", " ")
    print("[DEBUG] clean[:50]:", clean[:50], "...", flush=True)
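
    # Severity prediction: softmax over the two logits gives
    # [P(not severe), P(severe)].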
    inputs = tokenizer(clean, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
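
    # Run biomedical NER over the cleaned text.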
    ents = ner(clean)
    print("[DEBUG] raw ents:", [(e["entity_group"], e["word"], e["start"], e["end"]) for e in ents], flush=True)
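
    # Merge adjacent or overlapping spans with the same entity group,
    # keeping the highest score seen for the merged span.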
    spans = []
    for ent in ents:
        grp, start, end, score = ent["entity_group"].lower(), ent["start"], ent["end"], ent.get("score", 1.0)
        if spans and spans[-1]["group"] == grp and start <= spans[-1]["end"]:
            spans[-1]["end"] = max(spans[-1]["end"], end)
            spans[-1]["score"] = max(spans[-1]["score"], score)
        else:
            spans.append({"group": grp, "start": start, "end": end, "score": score})
    print("[DEBUG] merged spans:", spans, flush=True)
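
    # Medication spans can end mid-word; extend each one to the next word boundary.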
    for s in spans:
        if s["group"] in MED_TAGS:
            en = s["end"]
            while en < len(clean) and clean[en].isalpha():
                en += 1
            s["end"] = en
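
    # Drop low-confidence entities.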
    spans = [s for s in spans if s["score"] >= 0.6]
    print("[DEBUG] post-filter spans:", spans, flush=True)

    tokens = [clean[s["start"]:s["end"]] for s in spans]
    print("[DEBUG] tokens:", tokens, flush=True)
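
    # Bucket the entity texts by group, deduplicated and filtered.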
    symptoms = dedupe_and_filter([t for t, s in zip(tokens, spans) if s["group"] in SYMPTOM_TAGS])
    diseases = dedupe_and_filter([t for t, s in zip(tokens, spans) if s["group"] in DISEASE_TAGS])
    medications = dedupe_and_filter([t for t, s in zip(tokens, spans) if s["group"] in MED_TAGS])
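
    # Plain-language interpretation of the severe-class probability.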
    if probs[1] > 0.9:
        comment = "High confidence this is a severe ADR."
    elif probs[1] > 0.5:
        comment = "Borderline case; may be severe."
    else:
        comment = "Likely not severe."

    return (
        f"Not Severe (0): {probs[0]:.3f}\nSevere (1): {probs[1]:.3f}",
        "\n".join(symptoms) or "None detected",
        "\n".join(diseases) or "None detected",
        "\n".join(medications) or "None detected",
        comment,
    )
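

# Gradio interface: one input textbox, five text outputs, flagging disabled.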
demo = gr.Interface(
    fn=classify_adr,
    inputs=gr.Textbox(lines=4, label="ADR Description"),
    outputs=[
        gr.Textbox(label="Predicted Probabilities"),
        gr.Textbox(label="Symptoms"),
        gr.Textbox(label="Diseases or Conditions"),
        gr.Textbox(label="Medications"),
        gr.Textbox(label="Interpretation"),
    ],
    title="ADR Severity & NER Classifier",
    description="Paste an ADR description to classify severity and extract symptoms, diseases & medications.",
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()