"""Gradio demo that highlights model-predicted entities in free text.

Loads the ``austin/adr-ner`` token-classification pipeline, runs it over the
text typed by the user, and renders every grouped entity as a coloured HTML
span whose tooltip shows the predicted label and score.
"""
import html

import gradio as gr
import matplotlib.cm as cm
import numpy as np
import torch  # not referenced directly in this file; kept from the original import list
from transformers import pipeline


def value2rgba(x, cmap=cm.RdYlGn, alpha_mult=1.0):
    """Convert a value `x` from 0 to 1 (inclusive) to an RGBA tuple.

    Args:
        x: Value handed to `cmap` to look up a colour.
        cmap: Callable returning an ``(r, g, b, a)`` sequence for `x`.
        alpha_mult: Factor the colormap's alpha channel is multiplied by.

    Returns:
        A 4-tuple ``(r, g, b, a)``: the colour channels are the colormap
        values multiplied by 255 and truncated to ``int``; ``a`` is the
        colormap alpha times `alpha_mult`.
    """
    c = cmap(x)
    rgb = (np.array(c[:-1]) * 255).astype(int)
    a = c[-1] * alpha_mult
    return tuple(rgb.tolist() + [a])


def piece_prob_html(pieces, prob, colors, label, sep=' ', **kwargs):
    """Render text pieces as HTML spans shaded by their prediction score.

    Args:
        pieces: Text fragments to display (HTML-escaped before rendering).
        prob: One score per piece; selects the colour and is shown in the
            tooltip with three decimals.
        colors: One colormap per piece, forwarded to `value2rgba` as `cmap`.
        label: One label per piece, shown in the tooltip.
        sep: String placed between consecutive spans.
        **kwargs: Extra keyword arguments forwarded to `value2rgba`.

    Returns:
        A single HTML string; empty when `pieces` is empty.
    """
    spans = []
    for piece, score, cmap, lab in zip(pieces, prob, colors, label):
        shown = html.escape(piece)
        # The tooltip ends up inside a double-quoted attribute, so it is
        # escaped as well (html.escape also escapes quotes by default).
        tooltip = html.escape(f'{lab}: {score:.3f}')
        r, g, b, a = value2rgba(score, cmap=cmap, alpha_mult=0.5, **kwargs)
        # Format the channels explicitly rather than via str(tuple): the
        # tuple repr of a NumPy scalar is not valid inside CSS rgba().
        css = f'rgba({r}, {g}, {b}, {float(a):.3f})'
        spans.append(
            f'<span title="{tooltip}" style="background-color: {css};">{shown}</span>'
        )
    return sep.join(spans)


def nothing_ent(i, word):
    """Build a placeholder 'O' (no entity) record for token `word` at position `i`.

    The dict uses the same keys that `_gradio_highlighting` reads from the
    pipeline output, with a score of 0 and zeroed character offsets.
    """
    return {
        'entity': 'O',
        'score': 0,
        'index': i,
        'word': word,
        'start': 0,
        'end': 0
    }


def _gradio_highlighting(text):
    """Run the NER pipeline on `text` and return highlighted HTML.

    Every token produced by the pipeline's tokenizer gets an entity record:
    the model's prediction when one exists for that position, otherwise an
    'O' placeholder. Records are then merged with the pipeline's
    `group_entities` and rendered by `piece_prob_html`.

    Args:
        text: Raw text entered in the Gradio textbox.

    Returns:
        HTML string with one coloured span per grouped entity; empty string
        for empty or whitespace-only input.
    """
    if not text or not text.strip():
        # Nothing to tag; skip the model call.
        return ''

    result = ner_model(text)
    tokens = ner_model.tokenizer.tokenize(text)

    # Map token position -> prediction. The pipeline's 'index' is shifted by
    # one relative to tokenize() positions (presumably because it counts a
    # leading special token - confirm against the tokenizer in use).
    # setdefault keeps the first prediction per position.
    predicted = {}
    for ent in result:
        predicted.setdefault(ent['index'] - 1, ent)

    entities = [
        predicted[i] if i in predicted else nothing_ent(i, word)
        for i, word in enumerate(tokens)
    ]
    entities = ner_model.group_entities(entities)

    spans = [e['word'] for e in entities]
    probs = [e['score'] for e in entities]
    labels = [e['entity_group'] for e in entities]
    # 'ADR' groups use the red-purple map; every other label uses yellow-green.
    colors = [cm.RdPu if lab == 'ADR' else cm.YlGn for lab in labels]
    return piece_prob_html(spans, probs, colors, labels, sep=' ')


default_text = """# Pancreatitis - Lipase: 535 -> 154 -> 145 - Managed with NBM, IV fluids - CT AP and abdo USS: normal - Likely secondary to Azathioprine - ceased, never to be used again. - Resolved with conservative measures """

title = "Adverse Drug Reaction Highlighting"
description = "Named Entity Recognition model to detect ADRs in discharge summaries"
article = """This app was made to accompany our recent [paper](https://www.medrxiv.org/content/10.1101/2021.12.11.21267504v2).
ADRs will be highlighted in purple, offending medications in green.
Hover over a word to see the strength of each prediction on a 0-1 scale.
Our training code can be found at [github](https://github.com/AustinMOS/adr-nlp). """

# Loaded once at import time; _gradio_highlighting reads this module-level name.
ner_model = pipeline(task='token-classification', model="austin/adr-ner")

iface = gr.Interface(
    _gradio_highlighting,
    [
        gr.inputs.Textbox(
            lines=7,
            label="Text",
            default=default_text,
        ),
    ],
    gr.outputs.HTML(label="ADR Prediction"),
    title=title,
    description=description,
    article=article,
    theme="huggingface",
)

iface.launch()