# adr-detection / app.py
# Last commit by Christopher McMaster: 079a915 ("Update app.py")
# File size: 2.34 kB
import matplotlib.cm as cm
import html
import torch
import numpy as np
from transformers import pipeline
import gradio as gr
def value2rgba(x, cmap=cm.RdYlGn, alpha_mult=1.0):
    """Map a value ``x`` (0 to 1 inclusive) through ``cmap`` to an RGBA tuple.

    Every channel except the last is scaled by 255 and converted to ``int``
    (truncating, not rounding). The last channel (alpha) is left as returned
    by ``cmap`` and multiplied by ``alpha_mult``.

    Args:
        x: Value to look up in the colormap.
        cmap: Callable colormap; ``cmap(x)`` must return a sequence whose
            last element is the alpha channel. Defaults to ``cm.RdYlGn``.
        alpha_mult: Factor applied to the colormap's alpha channel.

    Returns:
        Tuple of integer colour channels followed by the scaled alpha.
    """
    *channels, alpha = cmap(x)
    scaled = (np.array(channels) * 255).astype(int)
    return tuple(scaled.tolist()) + (alpha * alpha_mult,)
def piece_prob_html(pieces, prob, colors, label, sep=' ', **kwargs):
    """Render text pieces as inline HTML spans shaded by their scores.

    Each piece becomes a ``<span>``:
    - its background colour is ``value2rgba(score, cmap=<colormap>)``;
    - its tooltip reads ``"<label>: <score to 3 decimals>"``.
    The spans are joined with ``sep`` and wrapped in a monospace ``<span>``.

    Args:
        pieces: Text fragments to display; HTML-escaped before insertion.
        prob: Per-piece scores, passed to ``value2rgba`` as ``x`` and shown
            in the tooltip.
        colors: Per-piece colormaps, passed to ``value2rgba`` as ``cmap``.
        label: Per-piece tooltip labels; HTML-escaped before insertion.
        sep: String placed between consecutive spans.
        **kwargs: Extra keyword arguments forwarded to ``value2rgba``.
            ``alpha_mult`` defaults to 0.5 and may be overridden here.

    Returns:
        The HTML fragment as a single string.

    The four sequences are consumed in lockstep with ``zip``, so rendering
    stops at the shortest one.
    """
    # alpha_mult used to be passed explicitly alongside **kwargs, so a caller
    # supplying it got "multiple values for keyword argument"; treat 0.5 as an
    # overridable default instead.
    rgba_options = {'alpha_mult': 0.5, **kwargs}
    spans = []
    for piece, score, cmap, tag in zip(pieces, prob, colors, label):
        rgba = str(value2rgba(score, cmap=cmap, **rgba_options))
        # The label sits inside a double-quoted attribute, so escape it as well
        # as the visible text; an unescaped quote or '<' would break the markup.
        spans.append(
            f'<span title="{html.escape(str(tag))}: {score:.3f}" '
            f'style="background-color: rgba{rgba};">{html.escape(piece)}</span>'
        )
    return '<span style="font-family: monospace;">' + sep.join(spans) + '</span>'
def nothing_ent(i, word):
    """Return a placeholder entity dict for a token with no predicted entity.

    The entry has label ``'O'``, a score of 0, and zeroed ``start``/``end``
    offsets.

    Args:
        i: Value stored under ``'index'``.
        word: Token text stored under ``'word'``.

    Returns:
        A new dict with keys ``entity``, ``score``, ``index``, ``word``,
        ``start`` and ``end``.
    """
    return dict(entity='O', score=0, index=i, word=word, start=0, end=0)
def _gradio_highlighting(text):
    """Run the ADR NER model on ``text`` and return highlighted HTML.

    Steps:
    1. Pair every token from ``ner_model.tokenizer.tokenize`` with either the
       pipeline's prediction for that position or an ``'O'`` placeholder
       from ``nothing_ent``.
    2. Merge the sequence with ``ner_model.group_entities``.
    3. Render the groups with ``piece_prob_html``. Groups labelled ``'ADR'``
       are shaded with ``cm.RdPu``; all others use ``cm.YlGn``.

    Args:
        text: Raw input text from the Gradio textbox.

    Returns:
        HTML string for the Gradio HTML output component.
    """
    result = ner_model(text)
    tokens = ner_model.tokenizer.tokenize(text)
    # Pipeline indices are offset by one from positions in ``tokens``
    # (presumably a leading special token such as [CLS] — TODO confirm for
    # this tokenizer). Keying a dict by token position replaces the previous
    # list-membership + .index() scan, which was O(tokens x predictions).
    predicted = {ent['index'] - 1: ent for ent in result}
    entities = [
        # Placeholders get the same 1-based index as real predictions so the
        # whole sequence shares one numbering. Mixing 0- and 1-based indices
        # produces duplicate 'index' values, which confuses group_entities
        # implementations that compare indices for adjacency / last element.
        predicted[pos] if pos in predicted else nothing_ent(pos + 1, word)
        for pos, word in enumerate(tokens)
    ]
    groups = ner_model.group_entities(entities)
    spans = [group['word'] for group in groups]
    probs = [group['score'] for group in groups]
    labels = [group['entity_group'] for group in groups]
    colors = [cm.RdPu if label == 'ADR' else cm.YlGn for label in labels]
    return piece_prob_html(spans, probs, colors, labels, sep=' ')
# Example clinical note pre-filled in the input textbox.
default_text = """# Pancreatitis
- Lipase: 535 -> 154 -> 145
- Managed with NBM, IV fluids
- CT AP and abdo USS: normal
- Likely secondary to Azathioprine - ceased, never to be used again.
- Resolved with conservative measures
"""

# Loaded at module level because _gradio_highlighting reads this global.
ner_model = pipeline(task='token-classification', model="austin/adr-ner")

# NOTE(review): gr.inputs / gr.outputs are Gradio's legacy component
# namespaces (deprecated in 3.x, removed in 4.0). Confirm the pinned Gradio
# version before migrating to gr.Textbox(value=...) / gr.HTML(...).
iface = gr.Interface(
    _gradio_highlighting,
    [
        gr.inputs.Textbox(
            lines=7,
            label="Text",
            default=default_text,
        ),
    ],
    gr.outputs.HTML(label="ADR Prediction"),
)

# Only start the server when run as a script (e.g. `python app.py`), so
# importing this module does not block on a web server.
if __name__ == "__main__":
    iface.launch()