File size: 2,999 Bytes
737007d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
079a915
d12696e
1b37606
5fa3012
 
 
41b0f29
1b6eade
d12696e
737007d
 
 
 
 
 
 
 
 
d12696e
 
1b37606
1b6eade
079a915
737007d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import matplotlib.cm as cm
import html
import torch
import numpy as np
from transformers import pipeline
import gradio as gr
def value2rgba(x, cmap=cm.RdYlGn, alpha_mult=1.0):
    """Convert a value `x` from 0 to 1 (inclusive) to an RGBA tuple according to `cmap` times transparency `alpha_mult`.

    Returns a 4-tuple ``(r, g, b, a)``:
    - ``r``, ``g``, ``b`` are ints, computed as the colormap's first three channels scaled by 255 and truncated. They fall in 0..255 assuming the colormap's channels lie in [0, 1].
    - ``a`` is a builtin float: the colormap's last channel times ``alpha_mult``.
    """
    c = cmap(x)
    rgb = (np.array(c[:-1]) * 255).astype(int)
    # Coerce alpha to a builtin float. The colormap may hand back a numpy
    # scalar, whose repr under NumPy >= 2 is e.g. "np.float64(0.5)"; callers
    # that embed str() of this tuple into CSS would then emit invalid CSS.
    a = float(c[-1] * alpha_mult)
    return tuple(rgb.tolist() + [a])
def piece_prob_html(pieces, prob, colors, label, sep=' ', **kwargs):
    """Render text pieces as HTML spans shaded by their scores.

    Args:
        pieces: Text fragments to display; each one is HTML-escaped.
        prob: One score per piece. It is the colormap input and is also shown
            in the hover tooltip, formatted to 3 decimals.
        colors: One colormap per piece, passed to `value2rgba` as `cmap`.
        label: One label per piece, shown in the hover tooltip.
        sep: String placed between consecutive piece spans.
        **kwargs: Extra keyword arguments forwarded to `value2rgba`.

    Returns:
        One HTML string: a monospace ``<span>`` wrapping one coloured
        ``<span>`` per piece. Iteration stops at the shortest of the four
        input sequences (``zip`` semantics).
    """
    spans = []
    for piece, score, cmap, lab in zip(pieces, prob, colors, label):
        rgba = value2rgba(score, cmap=cmap, alpha_mult=0.5, **kwargs)
        # Format components explicitly rather than via str(tuple): the tuple
        # repr would leak numpy scalar reprs (e.g. "np.float64(0.5)") into
        # the CSS if any component were a numpy scalar.
        css = ', '.join(str(v) for v in rgba)
        # The label goes into an HTML attribute, so escape it like the text.
        title = html.escape(f'{lab}: {score:.3f}')
        spans.append(
            f'<span title="{title}" style="background-color: rgba({css});">'
            f'{html.escape(piece)}</span>'
        )
    return '<span style="font-family: monospace;">' + sep.join(spans) + '</span>'
def nothing_ent(i, word):
    """Return a placeholder entity for a token the model left untagged.

    The placeholder is labelled ``'O'`` with score 0, carries the given token
    position ``i`` and text ``word``, and has zeroed ``start``/``end``
    offsets. It sits alongside the pipeline's own entity dicts in the list
    that is later grouped for display.
    """
    return dict(entity='O', score=0, index=i, word=word, start=0, end=0)
def _gradio_highlighting(text):
    """Run the NER model on `text` and return the result as highlighted HTML.

    Every token produced by the model's tokenizer is kept:
    - a token the pipeline returned a prediction for uses that prediction;
    - every other token gets an 'O' placeholder from `nothing_ent`.

    The full list is merged with the pipeline's `group_entities`. Each group
    is coloured with `cm.RdPu` when labelled 'ADR' and `cm.YlGn` otherwise,
    then rendered by `piece_prob_html`.
    """
    result = ner_model(text)
    tokens = ner_model.tokenizer.tokenize(text)
    # Pipeline 'index' values are shifted by one relative to `tokens`
    # (presumably a leading special token - TODO confirm). Build the
    # position -> prediction map once so each token lookup is O(1);
    # setdefault keeps the first match, as list.index did.
    predictions = {}
    for ent in result:
        predictions.setdefault(ent['index'] - 1, ent)
    entities = [
        predictions[i] if i in predictions else nothing_ent(i, word)
        for i, word in enumerate(tokens)
    ]
    groups = ner_model.group_entities(entities)
    spans = [g['word'] for g in groups]
    probs = [g['score'] for g in groups]
    labels = [g['entity_group'] for g in groups]
    colors = [cm.RdPu if lab == 'ADR' else cm.YlGn for lab in labels]
    return piece_prob_html(spans, probs, colors, labels, sep=' ')
    
# Example clinical note pre-filled in the input box.
default_text = """# Pancreatitis
    - Lipase: 535 -> 154 -> 145
    - Managed with NBM, IV fluids
    - CT AP and abdo USS: normal
    - Likely secondary to Azathioprine - ceased, never to be used again.
    - Resolved with conservative measures 
"""

# Text displayed around the interface.
title = "Adverse Drug Reaction Highlighting"
description = "Named Entity Recognition model to detect ADRs in discharge summaries"
article = """This app was made to accompany our recent [paper](https://www.medrxiv.org/content/10.1101/2021.12.11.21267504v2).<br>
ADRs will be highlighted in <span style="color:purple">purple</span>, offending medications in <span style="color:green">green</span>.<br>
Hover over a word to see the strength of each prediction on a 0-1 scale.<br>
Our training code can be found at [github](https://github.com/AustinMOS/adr-nlp).
"""

# Token-classification pipeline; `_gradio_highlighting` reads this
# module-level name at call time.
ner_model = pipeline(task='token-classification', model="austin/adr-ner")

# Input/output components, built in the same order as before.
text_input = gr.inputs.Textbox(lines=7, label="Text", default=default_text)
html_output = gr.outputs.HTML(label="ADR Prediction")

iface = gr.Interface(
    _gradio_highlighting,
    [text_input],
    html_output,
    title=title,
    description=description,
    article=article,
    theme="huggingface",
)
iface.launch()