# adr-detection / app.py
# Author: Christopher McMaster
# Commit message: Update app.py
# Commit: 5fa3012
import matplotlib.cm as cm
import html
import torch
import numpy as np
from transformers import pipeline
import gradio as gr
def value2rgba(x, cmap=cm.RdYlGn, alpha_mult=1.0):
    """Map a value ``x`` in [0, 1] to an ``(r, g, b, a)`` tuple using ``cmap``.

    The colour channels are scaled to 0-255 and converted with ``astype(int)``
    (truncation). The alpha channel is the colormap's own alpha multiplied by
    ``alpha_mult``.
    """
    # Everything but the last component is colour; the last one is alpha.
    *channels, alpha = cmap(x)
    scaled = (np.asarray(channels) * 255).astype(int)
    return (*scaled.tolist(), alpha * alpha_mult)
def piece_prob_html(pieces, prob, colors, label, sep=' ', alpha_mult=0.5, **kwargs):
    """Render text pieces as colour-highlighted HTML spans with hover tooltips.

    Args:
        pieces: Text fragments to display. Each one is HTML-escaped.
        prob: One score per piece. It is passed to ``value2rgba`` to pick the
            background colour and is shown to 3 decimal places in the hover
            title.
        colors: One colormap callable per piece (e.g. a matplotlib colormap),
            passed to ``value2rgba`` as ``cmap``.
        label: One label per piece, shown in the hover title.
        sep: String placed between consecutive spans.
        alpha_mult: Multiplier applied to the colormap alpha. The default of
            0.5 matches the previously hard-coded value.
        **kwargs: Extra keyword arguments forwarded to ``value2rgba``.

    Returns:
        One HTML string: the per-piece spans joined by ``sep`` and wrapped in
        a monospace ``<span>``. The four sequences are zipped together, so
        surplus items in a longer sequence are ignored.
    """
    spans = []
    for piece, score, cmap, tag in zip(pieces, prob, colors, label):
        rgba = value2rgba(score, cmap=cmap, alpha_mult=alpha_mult, **kwargs)
        # Build the CSS tuple from str() of each component rather than str() of
        # the whole tuple. str(tuple) uses repr() for its elements, and NumPy >= 2
        # reprs scalars as "np.float64(...)", which would corrupt the CSS.
        # str() of each element gives the same output as before on NumPy 1.x.
        css_rgba = '(' + ', '.join(str(component) for component in rgba) + ')'
        # The label goes into an HTML attribute, so escape it (quotes included)
        # just as the visible text is escaped.
        title = html.escape(f'{tag}: {score:.3f}')
        spans.append(
            f'<span title="{title}" style="background-color: rgba{css_rgba};">'
            f'{html.escape(piece)}</span>'
        )
    return '<span style="font-family: monospace;">' + sep.join(spans) + '</span>'
def nothing_ent(i, word):
    """Return a placeholder entity for token ``word`` at position ``i``.

    The entity is tagged ``'O'``, presumably the model's "no entity" tag.
    Score, start and end are all zero. It uses the same keys as the pipeline's
    per-token results, so it can sit alongside them in one list before
    grouping.
    """
    return dict(entity='O', score=0, index=i, word=word, start=0, end=0)
def _gradio_highlighting(text):
    """Run the ADR NER model on ``text`` and return highlighted HTML.

    Every token from the model's tokenizer is kept:
      * tokens the pipeline labelled use its prediction;
      * all other tokens get a ``nothing_ent`` placeholder.
    The rendered output therefore covers the whole input, not only the
    detected entities. Adjacent tokens are merged with the pipeline's
    ``group_entities`` and rendered by ``piece_prob_html``. Groups labelled
    ``'ADR'`` use ``cm.RdPu``; every other group uses ``cm.YlGn``.
    """
    result = ner_model(text)
    tokens = ner_model.tokenizer.tokenize(text)

    # Map token position -> prediction. Pipeline indices are shifted by one
    # relative to tokenize(), presumably because they count a leading special
    # token (e.g. [CLS]) that tokenize() omits (confirm for this tokenizer).
    # A dict lookup replaces the O(n^2) list.index search. setdefault keeps
    # the first prediction per position, matching list.index semantics.
    predicted = {}
    for ent in result:
        predicted.setdefault(ent['index'] - 1, ent)

    entities = [
        predicted[i] if i in predicted else nothing_ent(i, word)
        for i, word in enumerate(tokens)
    ]
    groups = ner_model.group_entities(entities)

    spans = [g['word'] for g in groups]
    probs = [g['score'] for g in groups]
    labels = [g['entity_group'] for g in groups]
    colors = [cm.RdPu if lab == 'ADR' else cm.YlGn for lab in labels]
    return piece_prob_html(spans, probs, colors, labels, sep=' ')
# Example note pre-filled in the input box (one markdown bullet per line,
# with a trailing newline).
default_text = "\n".join([
    "# Pancreatitis",
    "- Lipase: 535 -> 154 -> 145",
    "- Managed with NBM, IV fluids",
    "- CT AP and abdo USS: normal",
    "- Likely secondary to Azathioprine - ceased, never to be used again.",
    "- Resolved with conservative measures",
    "",
])

# Text shown around the Gradio interface.
title = "Adverse Drug Reaction Highlighting"
description = "Named Entity Recognition model to detect ADRs in discharge summaries"
# Footer: lines are separated by "<br>" plus a newline; the last line ends with a newline.
article = "<br>\n".join([
    "This app was made to accompany our recent [paper](https://www.medrxiv.org/content/10.1101/2021.12.11.21267504v2).",
    'ADRs will be highlighted in <span style="color:purple">purple</span>, offending medications in <span style="color:green">green</span>.',
    "Hover over a word to see the strength of each prediction on a 0-1 scale.",
    "Our training code can be found at [github](https://github.com/AustinMOS/adr-nlp).\n",
])
# Load the ADR token-classification model once, at import time.
# _gradio_highlighting reads this module-level name.
ner_model = pipeline(task='token-classification', model="austin/adr-ner")

# NOTE(review): gr.inputs / gr.outputs and theme="huggingface" are the legacy
# Gradio 2.x API. Confirm the pinned gradio version before upgrading.
# Input: a 7-line textbox pre-filled with the example note.
text_box = gr.inputs.Textbox(lines=7, label="Text", default=default_text)
# Output: the HTML string produced by _gradio_highlighting.
prediction_html = gr.outputs.HTML(label="ADR Prediction")

iface = gr.Interface(
    _gradio_highlighting,
    [text_box],
    prediction_html,
    title=title,
    description=description,
    article=article,
    theme="huggingface",
)
iface.launch()