# app.py — ADR severity & NER classifier (Hugging Face Space, calerio-uva)
import re

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
# 1) Severity classification model — fine-tuned RoBERTa, downloaded from the
#    Hugging Face Hub at import time (requires network on first run).
model = AutoModelForSequenceClassification.from_pretrained("calerio-uva/roberta-adr-model")
tokenizer = AutoTokenizer.from_pretrained("calerio-uva/roberta-adr-model")
# 2) Unified biomedical NER pipeline; aggregation_strategy="simple" merges
#    sub-word pieces into whole-entity spans with character offsets.
ner = pipeline(
    "ner",
    model="d4data/biomedical-ner-all",
    tokenizer="d4data/biomedical-ner-all",
    aggregation_strategy="simple"
)
# 3) Tight tag sets: entity_group labels (compared lowercased) bucketed into
#    the three output categories shown in the UI.
SYMPTOM_TAGS = {"sign_symptom", "symptom"}
DISEASE_TAGS = {"disease_disorder"}
MED_TAGS = {"medication", "administration", "therapeutic_procedure"}
# 4) Helper: normalize, length-filter, and case-insensitively dedupe tokens
def dedupe_and_filter(tokens):
    """Strip each token, drop anything shorter than 3 characters, and keep
    only the first occurrence of each word (case-insensitive), preserving
    the original casing and input order."""
    seen_lower = set()
    kept = []
    for raw in tokens:
        word = raw.strip()
        if len(word) < 3:
            continue
        key = word.lower()
        if key in seen_lower:
            continue
        seen_lower.add(key)
        kept.append(word)
    return kept
def classify_adr(text: str):
    """Classify ADR severity and extract biomedical entities from *text*.

    Returns a 5-tuple of strings for the Gradio outputs:
    (probabilities, symptoms, diseases, medications, interpretation).
    """
    print("🔍 [DEBUG] Running classify_adr", flush=True)
    # Clean: remove the literal placeholder token "nan" as a WHOLE WORD only
    # (the old `.replace("nan", "")` mangled words merely containing "nan",
    # e.g. "pregnant" -> "pregt"), then collapse whitespace runs — the old
    # `.replace(" ", " ")` was a no-op space-for-space swap.
    clean = re.sub(r"\bnan\b", " ", text.strip())
    clean = re.sub(r"\s+", " ", clean).strip()
    print("🔍 [DEBUG] clean[:50]:", clean[:50], "...", flush=True)
    # Severity: softmax over the 2-class logits of the RoBERTa classifier.
    inputs = tokenizer(clean, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
    # Raw NER over the cleaned text (offsets below index into `clean`).
    ents = ner(clean)
    print("🔍 [DEBUG] raw ents:", [(e["entity_group"], e["word"], e["start"], e["end"]) for e in ents], flush=True)
    # 1) Merge adjacent/overlapping spans of the same entity group by offsets.
    spans = []
    for ent in ents:
        grp, start, end, score = ent["entity_group"].lower(), ent["start"], ent["end"], ent.get("score", 1.0)
        if spans and spans[-1]["group"] == grp and start <= spans[-1]["end"]:
            spans[-1]["end"] = max(spans[-1]["end"], end)
            spans[-1]["score"] = max(spans[-1]["score"], score)
        else:
            spans.append({"group": grp, "start": start, "end": end, "score": score})
    print("🔍 [DEBUG] merged spans:", spans, flush=True)
    # 2) Extend medication spans forward to the full word — sub-word
    #    tokenization can cut drug names mid-word.
    for s in spans:
        if s["group"] in MED_TAGS:
            en = s["end"]
            while en < len(clean) and clean[en].isalpha():
                en += 1
            s["end"] = en
    # 3) Keep only confident spans.
    spans = [s for s in spans if s["score"] >= 0.6]
    print("🔍 [DEBUG] post-filter spans:", spans, flush=True)
    # 4) Recover the surface text for each surviving span.
    tokens = [clean[s["start"]:s["end"]] for s in spans]
    print("🔍 [DEBUG] tokens:", tokens, flush=True)
    # Bucket by entity group, then dedupe/length-filter each bucket.
    symptoms = dedupe_and_filter([t for t, s in zip(tokens, spans) if s["group"] in SYMPTOM_TAGS])
    diseases = dedupe_and_filter([t for t, s in zip(tokens, spans) if s["group"] in DISEASE_TAGS])
    medications = dedupe_and_filter([t for t, s in zip(tokens, spans) if s["group"] in MED_TAGS])
    # Interpretation thresholds on P(severe) = probs[1].
    if probs[1] > 0.9:
        comment = "❗ High confidence this is a severe ADR."
    elif probs[1] > 0.5:
        comment = "⚠️ Borderline case — may be severe."
    else:
        comment = "✅ Likely not severe."
    return (
        f"Not Severe (0): {probs[0]:.3f}\nSevere (1): {probs[1]:.3f}",
        "\n".join(symptoms) or "None detected",
        "\n".join(diseases) or "None detected",
        "\n".join(medications) or "None detected",
        comment
    )
# 5) Gradio UI — one multi-line text input, five text outputs matching the
#    5-tuple returned by classify_adr (in order).
demo = gr.Interface(
    fn=classify_adr,
    inputs=gr.Textbox(lines=4, label="ADR Description"),
    outputs=[
        gr.Textbox(label="Predicted Probabilities"),
        gr.Textbox(label="Symptoms"),
        gr.Textbox(label="Diseases or Conditions"),
        gr.Textbox(label="Medications"),
        gr.Textbox(label="Interpretation"),
    ],
    title="ADR Severity & NER Classifier",
    description="Paste an ADR description to classify severity and extract symptoms, diseases & medications.",
    allow_flagging="never"  # NOTE(review): renamed to flagging_mode in newer Gradio — confirm pinned version
)
# Start the Gradio web server only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()