# adr-detection / app.py
# Author: Christopher McMaster
# Last commit 9daccf3: "Update app.py" (file size 2.99 kB)
import matplotlib.cm as cm
import html
import torch
import numpy as np
from transformers import pipeline
import gradio as gr
def value2rgba(x, cmap=cm.RdYlGn, alpha_mult=1.0):
    """Map a value ``x`` in [0, 1] to an RGBA tuple through ``cmap``.

    Args:
        x: Position in the colormap, from 0 to 1 inclusive.
        cmap: Callable returning a colour sequence whose last element is alpha
            (e.g. a matplotlib colormap). Defaults to ``cm.RdYlGn``.
        alpha_mult: Factor applied to the colormap's alpha component.

    Returns:
        A tuple of the colour channels scaled to 0-255 and truncated to ints,
        followed by the alpha component multiplied by ``alpha_mult``.
    """
    *channels, alpha = cmap(x)
    # astype(int) truncates toward zero rather than rounding.
    scaled = (np.array(channels) * 255).astype(int).tolist()
    scaled.append(alpha * alpha_mult)
    return tuple(scaled)
def piece_prob_html(pieces, prob, colors, label, sep=' ', alpha_mult=0.5, **kwargs):
    """Render text pieces as HTML spans shaded by their prediction scores.

    Each piece gets a background colour taken from its own colormap at its
    score, and a hover tooltip of the form ``"<label>: <score>"``.

    Args:
        pieces: Text fragments to display; HTML-escaped before output.
        prob: Per-piece scores passed to ``value2rgba`` (expected in [0, 1]).
        colors: Per-piece colormaps passed to ``value2rgba`` as ``cmap``.
        label: Per-piece label names shown in the tooltip; HTML-escaped.
        sep: String placed between consecutive spans.
        alpha_mult: Transparency multiplier for the background colour.
        **kwargs: Extra keyword arguments forwarded to ``value2rgba``.

    Returns:
        One monospace ``<span>`` wrapping an inner span per piece. Like
        ``zip``, iteration stops at the shortest of the four sequences.
    """
    spans = []
    for piece, score, cmap, lab in zip(pieces, prob, colors, label):
        rgba = str(value2rgba(score, cmap=cmap, alpha_mult=alpha_mult, **kwargs))
        # The tooltip sits inside a double-quoted attribute, so it must be
        # escaped just like the visible text; a stray quote would break markup.
        tooltip = html.escape(f'{lab}: {score:.3f}', quote=True)
        spans.append(
            f'<span title="{tooltip}" style="background-color: rgba{rgba};">'
            f'{html.escape(piece)}</span>'
        )
    return '<span style="font-family: monospace;">' + sep.join(spans) + '</span>'
def nothing_ent(i, word):
    """Build a placeholder entity for a token the model did not tag.

    The entity carries the label 'O' with a zero score and zero character
    offsets. It uses the same keys as the pipeline predictions it is mixed
    with in ``_gradio_highlighting``.

    Args:
        i: Token position, stored under ``'index'``.
        word: Token text, stored under ``'word'``.
    """
    return dict(entity='O', score=0, index=i, word=word, start=0, end=0)
def _gradio_highlighting(text):
    """Run the ADR NER model on ``text`` and return highlighted HTML.

    Every token from the model's tokenizer gets an entity dict. Tokens the
    pipeline predicted keep that prediction; all others get an 'O' placeholder
    from ``nothing_ent``. The complete sequence is merged with the pipeline's
    ``group_entities`` and rendered by ``piece_prob_html``. Groups labelled
    'ADR' use the RdPu colormap and every other group uses YlGn.

    Args:
        text: Free text to analyse (the Gradio textbox contents).

    Returns:
        An HTML string with one shaded span per entity group.
    """
    result = ner_model(text)
    tokens = ner_model.tokenizer.tokenize(text)

    # Index predictions by token position in one pass, so lookups below stay
    # linear in the number of tokens. setdefault keeps the first prediction
    # if a position repeats.
    # NOTE(review): the "- 1" presumably compensates for a leading special
    # token (e.g. [CLS]) counted in the pipeline's 'index' but not by
    # tokenize(); confirm for this tokenizer.
    predicted = {}
    for ent in result:
        predicted.setdefault(ent['index'] - 1, ent)

    entities = [
        predicted[pos] if pos in predicted else nothing_ent(pos, word)
        for pos, word in enumerate(tokens)
    ]
    groups = ner_model.group_entities(entities)

    spans = [g['word'] for g in groups]
    probs = [g['score'] for g in groups]
    labels = [g['entity_group'] for g in groups]
    colors = [cm.RdPu if lab == 'ADR' else cm.YlGn for lab in labels]
    return piece_prob_html(spans, probs, colors, labels, sep=' ')
# Header text displayed on the Gradio interface.
title = 'Adverse Drug Reaction Highlighting'
description = 'Named Entity Recognition model to detect ADRs in discharge summaries'

# Example discharge-summary snippet pre-filled in the input textbox.
default_text = """# Pancreatitis
- Lipase: 535 -> 154 -> 145
- Managed with NBM, IV fluids
- CT AP and abdo USS: normal
- Likely secondary to Azathioprine - ceased, never to be used again.
- Resolved with conservative measures
"""

# Markdown/HTML shown below the interface: paper link, colour legend, repo link.
article = """This app was made to accompany our recent [paper](https://www.medrxiv.org/content/10.1101/2021.12.11.21267504v2).
ADRs will be highlighted in <span style="color:purple">purple</span>, offending medications in <span style="color:green">green</span>.
Hover over a word to see the strength of each prediction on a 0-1 scale.
Our training code can be found at [github](https://github.com/AustinMOS/adr-nlp).
"""
# Token-classification pipeline for the ADR NER model. _gradio_highlighting
# reads this module-level name.
ner_model = pipeline(model="austin/adr-ner", task='token-classification')

# NOTE(review): gr.inputs / gr.outputs, Textbox(default=...) and
# theme="huggingface" are legacy Gradio options; confirm they are supported by
# the gradio version this Space pins.
iface = gr.Interface(
    fn=_gradio_highlighting,
    inputs=[gr.inputs.Textbox(lines=7, label="Text", default=default_text)],
    outputs=gr.outputs.HTML(label="ADR Prediction"),
    title=title,
    description=description,
    article=article,
    theme="huggingface",
)
iface.launch()