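# Scraped Hugging Face file-view header (page residue, not program code):
# author: awacke1 | commit d3cf082 "Create app.py" | raw / history / blame
# scan: no virus | size: 2.05 kB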
import json
from collections import defaultdict
import matplotlib.pyplot as plt
import gradio as gr
import pandas as pd
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
# Hugging Face Hub id of the checkpoint. Kept in one constant so the tokenizer
# and the model are always loaded from the same repository.
MODEL_NAME = "d4data/biomedical-ner-all"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME)

# Agg is matplotlib's non-interactive backend; figures are rendered off-screen.
plt.switch_backend("Agg")

# examples.json is read as a list of objects carrying "text" and "label" keys.
# JSON is UTF-8 by specification, so decode it explicitly instead of relying
# on the platform's default locale encoding.
with open("examples.json", "r", encoding="utf-8") as f:
    example_json = json.load(f)

# Maps each example text to its label.
EXAMPLE_MAP = {x["text"]: x["label"] for x in example_json}

# Token-classification pipeline; the code in this file reads the
# "entity_group", "word", "score", "start" and "end" keys of its results.
pipe = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
def group_by_entity(raw):
    """Count how many entities fall into each entity group.

    Args:
        raw: Iterable of mappings, each providing an ``"entity_group"`` key.

    Returns:
        A ``defaultdict(int)`` mapping every entity group seen in ``raw`` to
        its number of occurrences, in first-seen order. Empty when ``raw``
        is empty.

    Raises:
        KeyError: If an item has no ``"entity_group"`` key.
    """
    out = defaultdict(int)
    for ent in raw:
        out[ent["entity_group"]] += 1
    return out
def plot_to_figure(grouped):
    """Render per-group entity counts as a bar chart.

    Args:
        grouped: Mapping of entity group name to count; keys become the bar
            positions on the x axis and values the bar heights.

    Returns:
        The matplotlib figure containing the chart.
    """
    # Draw on an explicit figure/axes pair rather than pyplot's implicit
    # "current figure", so the drawing calls cannot land on a figure that
    # another call made current in the meantime.
    fig, ax = plt.subplots()
    ax.bar(x=list(grouped.keys()), height=list(grouped.values()))
    ax.margins(0.2)
    # Leave room below the axes for the rotated tick labels.
    fig.subplots_adjust(bottom=0.4)
    ax.tick_params(axis="x", labelrotation=90)
    # Unregister the figure from pyplot so figures do not accumulate across
    # calls; the returned Figure object can still be rendered and saved.
    plt.close(fig)
    return fig
def ner(text):
    """Run entity recognition on ``text`` and build the values shown in the UI.

    Args:
        text: The note text to analyse.

    Returns:
        A 4-tuple ``(ner_content, meta, label, figure)``:

        * ``ner_content`` - dict with the input ``"text"`` and an
          ``"entities"`` list of dicts (``entity``, ``word``, ``score``,
          ``start``, ``end``) built from the pipeline output.
        * ``meta`` - dict with the per-group counts (``"entity_counts"``),
          the number of distinct groups (``"entities"``) and the total
          number of detected entities (``"counts"``).
        * ``label`` - the label stored in ``EXAMPLE_MAP`` for this exact
          text, or ``"Unknown"`` when the text is not one of the examples.
        * ``figure`` - bar chart of the per-group counts.
    """
    raw = pipe(text)
    ner_content = {
        "text": text,
        "entities": [
            {
                # The pipeline's "entity_group" is exposed under "entity".
                "entity": x["entity_group"],
                "word": x["word"],
                "score": x["score"],
                "start": x["start"],
                "end": x["end"],
            }
            for x in raw
        ],
    }
    grouped = group_by_entity(raw)
    figure = plot_to_figure(grouped)
    label = EXAMPLE_MAP.get(text, "Unknown")
    meta = {
        "entity_counts": grouped,
        # Keys of a mapping are already unique, so its length is the number
        # of distinct entity groups.
        "entities": len(grouped),
        "counts": sum(grouped.values()),
    }
    return (ner_content, meta, label, figure)
# Single free-text input holding the note to analyse.
note_input = gr.Textbox(label="Note text", value="")

# Output components, listed in the same order as the tuple returned by ner().
result_views = [
    gr.HighlightedText(label="NER", combine_adjacent=True),
    gr.JSON(label="Entity Counts"),
    gr.Label(label="Rating"),
    gr.Plot(label="Bar"),
]

# Every text from examples.json is offered as a clickable example.
interface = gr.Interface(
    fn=ner,
    inputs=note_input,
    outputs=result_views,
    examples=list(EXAMPLE_MAP),
    allow_flagging="never",
)

interface.launch()