File size: 4,143 Bytes
8a41839
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import re

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline

# 1) Sequence-classification model for ADR severity; fine-tuned RoBERTa with a
#    binary head (downloaded from the Hugging Face Hub at import time).
model = AutoModelForSequenceClassification.from_pretrained("calerio-uva/roberta-adr-model")
tokenizer = AutoTokenizer.from_pretrained("calerio-uva/roberta-adr-model")

# 2) Biomedical NER pipeline. aggregation_strategy="simple" merges word-piece
#    tokens into whole-entity spans carrying character start/end offsets,
#    which classify_adr relies on for slicing and span merging.
ner = pipeline(
    "ner",
    model="d4data/biomedical-ner-all",
    tokenizer="d4data/biomedical-ner-all",
    aggregation_strategy="simple"
)

# 3) NER entity-group names (lower-cased) bucketed into the three UI outputs.
SYMPTOM_TAGS = {"sign_symptom", "symptom"}
DISEASE_TAGS = {"disease_disorder"}
MED_TAGS     = {"medication", "administration", "therapeutic_procedure"}

# 4) Helper: strip, drop short tokens, and case-insensitively dedupe.
def dedupe_and_filter(tokens, min_len=3):
    """Return cleaned, de-duplicated tokens in first-seen order.

    Each token is whitespace-stripped; tokens shorter than ``min_len`` are
    dropped, and case-insensitive duplicates are removed (the first
    occurrence wins, preserving its original casing).

    Args:
        tokens: Iterable of strings to clean.
        min_len: Minimum stripped length to keep. Defaults to 3, matching
            the previous hard-coded cutoff, so existing callers are
            unaffected.

    Returns:
        List of surviving tokens, order preserved.
    """
    seen, out = set(), []
    for tok in tokens:
        word = tok.strip()
        if len(word) < min_len:
            continue
        key = word.lower()
        if key not in seen:
            seen.add(key)
            out.append(word)
    return out

def classify_adr(text: str):
    """Classify ADR severity and extract biomedical entities from free text.

    Args:
        text: Free-text adverse-drug-reaction description.

    Returns:
        A 5-tuple of strings for the Gradio outputs:
        (probability summary, symptoms, diseases, medications, interpretation).
    """
    print("πŸ” [DEBUG] Running classify_adr", flush=True)

    # Clean the input. FIX: the previous `.replace("nan", "")` deleted the
    # substring "nan" inside real words (e.g. "pregnancy" -> "pregcy"); use a
    # word-boundary regex so only standalone "nan" tokens (pandas NaN leakage)
    # are removed. FIX: `.replace("  ", " ")` collapsed only one level of
    # doubled spaces; collapse all whitespace runs instead.
    clean = re.sub(r"\s+", " ", re.sub(r"\bnan\b", "", text)).strip()
    print("πŸ” [DEBUG] clean[:50]:", clean[:50], "...", flush=True)

    # Severity classification. NOTE(review): assumes label id 1 == "severe"
    # in the model config β€” confirm against the checkpoint's id2label.
    inputs = tokenizer(clean, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=1)[0].cpu().numpy()

    # Raw NER over the cleaned text (spans carry character offsets).
    ents = ner(clean)
    print("πŸ” [DEBUG] raw ents:", [(e["entity_group"], e["word"], e["start"], e["end"]) for e in ents], flush=True)

    # 1) Merge adjacent/overlapping spans of the same entity group by offsets,
    #    keeping the higher confidence score of the merged pair.
    spans = []
    for ent in ents:
        grp, start, end, score = ent["entity_group"].lower(), ent["start"], ent["end"], ent.get("score", 1.0)
        if spans and spans[-1]["group"] == grp and start <= spans[-1]["end"]:
            spans[-1]["end"]   = max(spans[-1]["end"], end)
            spans[-1]["score"] = max(spans[-1]["score"], score)
        else:
            spans.append({"group": grp, "start": start, "end": end, "score": score})
    print("πŸ” [DEBUG] merged spans:", spans, flush=True)

    # 2) Extend medication spans forward to the end of the word, since the
    #    NER model can cut drug names mid-token.
    for s in spans:
        if s["group"] in MED_TAGS:
            en = s["end"]
            while en < len(clean) and clean[en].isalpha():
                en += 1
            s["end"] = en

    # 3) Keep only confident spans (score >= 0.6).
    spans = [s for s in spans if s["score"] >= 0.6]
    print("πŸ” [DEBUG] post‑filter spans:", spans, flush=True)

    # 4) Slice the surviving spans out of the cleaned text.
    tokens = [clean[s["start"]:s["end"]] for s in spans]
    print("πŸ” [DEBUG] tokens:", tokens, flush=True)

    # Bucket by entity group and dedupe within each bucket.
    symptoms    = dedupe_and_filter([t for t, s in zip(tokens, spans) if s["group"] in SYMPTOM_TAGS])
    diseases    = dedupe_and_filter([t for t, s in zip(tokens, spans) if s["group"] in DISEASE_TAGS])
    medications = dedupe_and_filter([t for t, s in zip(tokens, spans) if s["group"] in MED_TAGS])

    # Human-readable interpretation of the severe-class probability.
    if probs[1] > 0.9:
        comment = "❗ High confidence this is a severe ADR."
    elif probs[1] > 0.5:
        comment = "⚠️ Borderline case β€” may be severe."
    else:
        comment = "βœ… Likely not severe."

    return (
        f"Not Severe (0): {probs[0]:.3f}\nSevere (1): {probs[1]:.3f}",
        "\n".join(symptoms)    or "None detected",
        "\n".join(diseases)    or "None detected",
        "\n".join(medications) or "None detected",
        comment
    )

# 5) Gradio UI: one free-text input mapped to the five string outputs that
#    classify_adr returns, in the same order. Flagging is disabled.
demo = gr.Interface(
    fn=classify_adr,
    inputs=gr.Textbox(lines=4, label="ADR Description"),
    outputs=[
        gr.Textbox(label="Predicted Probabilities"),
        gr.Textbox(label="Symptoms"),
        gr.Textbox(label="Diseases or Conditions"),
        gr.Textbox(label="Medications"),
        gr.Textbox(label="Interpretation"),
    ],
    title="ADR Severity & NER Classifier",
    description="Paste an ADR description to classify severity and extract symptoms, diseases & medications.",
    allow_flagging="never"
)

# Launch the web app only when run as a script, not when imported as a module.
if __name__ == "__main__":
    demo.launch()