adr-detection / app.py
Christopher McMaster
Create app.py
737007d
raw
history blame
2.66 kB
import matplotlib.cm as cm
import html
import torch
import numpy as np
from transformers import pipeline
import gradio as gr
def value2rgba(x, cmap=cm.RdYlGn, alpha_mult=1.0):
    """Look up `x` (0 to 1 inclusive) in `cmap` and return an RGBA tuple.

    The colour channels are scaled by 255 and truncated to integers; the
    alpha returned by the colormap is multiplied by `alpha_mult` and
    appended as the last element.
    """
    colour = cmap(x)
    # Everything except the last component is a colour channel.
    channels = (np.array(colour[:-1]) * 255).astype(int).tolist()
    alpha = colour[-1] * alpha_mult
    return tuple(channels + [alpha])
def piece_prob_html(pieces, prob, colors, label, sep=' ', **kwargs):
    """Render text pieces as colour-coded HTML spans inside a monospace wrapper.

    `pieces`, `prob`, `colors` and `label` are walked in lockstep (iteration
    stops at the shortest). Each piece becomes a `<span>` whose background is
    `value2rgba(score, cmap=..., alpha_mult=0.5)` and whose tooltip reads
    "<label>: <score to 3 decimals>". The inner spans are joined with `sep`.
    Extra keyword arguments are forwarded to `value2rgba`.
    """
    def render(piece, score, cmap, tag):
        # Only the piece text is HTML-escaped; the label is inserted as-is.
        text = html.escape(piece)
        rgba = str(value2rgba(score, cmap=cmap, alpha_mult=0.5, **kwargs))
        return f'<span title="{tag}: {score:.3f}" style="background-color: rgba{rgba};">{text}</span>'

    body = sep.join(render(*item) for item in zip(pieces, prob, colors, label))
    return '<span style="font-family: monospace;">' + body + '</span>'
def nothing_ent(i, word):
    """Build a placeholder entity dict for a token that has no prediction.

    The dict carries the 'O' label, a score of 0, the given `i` as its
    index, the given `word`, and zeroed `start`/`end` offsets.
    """
    fields = ('entity', 'score', 'index', 'word', 'start', 'end')
    values = ('O', 0, i, word, 0, 0)
    return dict(zip(fields, values))
def _gradio_highlighting(text):
    """Run the NER model on `text` and return colour-highlighted HTML.

    Steps:
      1. Run `ner_model` on the text and tokenize the same text with the
         model's tokenizer.
      2. Build one entity per token: the model's prediction where there is
         one, otherwise an 'O' placeholder from `nothing_ent`.
      3. Merge the per-token entities with `ner_model.group_entities`.
      4. Render the groups with `piece_prob_html`, using the `cm.RdPu`
         colormap for groups labelled 'ADR' and `cm.YlGn` for all others.

    Returns the HTML string produced by `piece_prob_html`.
    """
    result = ner_model(text)
    tokens = ner_model.tokenizer.tokenize(text)

    # Map token position -> prediction once, instead of scanning a list for
    # every token. `setdefault` keeps the first prediction per position,
    # matching what `list.index` would have returned.
    # NOTE(review): the `- 1` assumes the pipeline's `index` is offset by one
    # relative to `tokenize()` output (e.g. a leading special token) - confirm.
    predicted = {}
    for ent in result:
        predicted.setdefault(ent['index'] - 1, ent)

    entities = [
        predicted[i] if i in predicted else nothing_ent(i, word)
        for i, word in enumerate(tokens)
    ]
    entities = ner_model.group_entities(entities)

    spans = [e['word'] for e in entities]
    probs = [e['score'] for e in entities]
    labels = [e['entity_group'] for e in entities]
    colors = [cm.RdPu if label == 'ADR' else cm.YlGn for label in labels]
    return piece_prob_html(spans, probs, colors, labels, sep=' ')
# Example clinical note pre-filled in the input textbox.
default_text = """# Pancreatitis
- Lipase: 535 -> 154 -> 145
- Managed with NBM, IV fluids
- CT AP and abdo USS: normal
- Likely secondary to Azathioprine - ceased, never to be used again.
- Resolved with conservative measures
"""

# Text shown at the top of the Gradio page.
title = "Adverse Drug Reaction Highlighting"
description = "This app was made to accompany our recent paper. ADRs will be highlighted in purple, offending medications in green. Hover over a word to see the strength of each prediction on a 0-1 scale."

# Token-classification pipeline used by `_gradio_highlighting`; it is created
# here at module level, before the interface is launched.
ner_model = pipeline(task='token-classification', model="austin/adr-ner")

# One text input in, one HTML panel out.
iface = gr.Interface(
    fn=_gradio_highlighting,
    inputs=[
        gr.inputs.Textbox(lines=7, label="Text", default=default_text),
    ],
    outputs=gr.outputs.HTML(label="ADR Prediction"),
    title=title,
    description=description,
    theme="darkdefault",
)
iface.launch()