awacke1 committed on
Commit
d3cf082
1 Parent(s): d2a2ae3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -0
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json
from collections import Counter, defaultdict

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
8
+
9
# Load the NER checkpoint once at import time so every request reuses the
# same tokenizer and model objects.
tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")

# Use the non-interactive Agg backend: figures are handed to the web UI as
# objects and are never shown in a desktop window.
plt.switch_backend("Agg")

# Map each example text to its label, as stored in examples.json.
# The encoding is explicit so the load does not depend on the platform's
# default text encoding.
with open("examples.json", "r", encoding="utf-8") as f:
    example_json = json.load(f)
EXAMPLE_MAP = {x["text"]: x["label"] for x in example_json}

# Token-classification pipeline; the code below reads the "entity_group",
# "word", "score", "start" and "end" keys from each result it returns.
pipe = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
20
+
21
+
22
def group_by_entity(raw):
    """Count how many detected entities fall into each entity group.

    Args:
        raw: Iterable of entity dicts, each carrying an ``"entity_group"``
            key.

    Returns:
        A ``Counter`` (a ``dict`` subclass) mapping each entity group to its
        number of occurrences, in first-seen order. Empty when ``raw`` is
        empty.

    Raises:
        KeyError: If an entity dict has no ``"entity_group"`` key.
    """
    return Counter(ent["entity_group"] for ent in raw)
28
+
29
+
30
def plot_to_figure(grouped):
    """Render per-group entity counts as a vertical bar chart.

    Args:
        grouped: Mapping of entity group name to its count.

    Returns:
        The matplotlib ``Figure`` containing the bar chart.
    """
    # Draw through this figure's own Axes instead of pyplot's implicit
    # "current figure", which is shared global state.
    fig, ax = plt.subplots()
    ax.bar(x=list(grouped.keys()), height=list(grouped.values()))
    ax.margins(0.2)
    # Reserve space below the axes for the rotated group names.
    fig.subplots_adjust(bottom=0.4)
    ax.tick_params(axis="x", labelrotation=90)
    # Unregister the figure from pyplot so figures do not accumulate across
    # calls; the Figure object itself remains usable by the caller.
    plt.close(fig)
    return fig
37
+
38
+
39
def ner(text):
    """Run entity recognition on ``text`` and build the interface outputs.

    Args:
        text: The note text to analyse.

    Returns:
        A 4-tuple ``(ner_content, meta, label, figure)``:

        * ``ner_content``: dict with the input ``"text"`` and an
          ``"entities"`` list of dicts (``entity``, ``word``, ``score``,
          ``start``, ``end``) built from the pipeline results.
        * ``meta``: dict with per-group counts (``"entity_counts"``), the
          number of distinct groups (``"entities"``) and the total number
          of detected entities (``"counts"``).
        * ``label``: the label stored for ``text`` in ``EXAMPLE_MAP``, or
          ``"Unknown"`` when ``text`` is not one of the examples.
        * ``figure``: bar chart of the per-group counts.
    """
    raw = pipe(text)
    ner_content = {
        "text": text,
        "entities": [
            {
                "entity": x["entity_group"],
                "word": x["word"],
                "score": x["score"],
                "start": x["start"],
                "end": x["end"],
            }
            for x in raw
        ],
    }
    grouped = group_by_entity(raw)
    figure = plot_to_figure(grouped)
    # Only texts that exactly match a bundled example have a known label.
    label = EXAMPLE_MAP.get(text, "Unknown")

    meta = {
        "entity_counts": grouped,
        # Mapping keys are already unique, so the length is the group count.
        "entities": len(grouped),
        "counts": sum(grouped.values()),
    }

    return (ner_content, meta, label, figure)
65
+
66
+
67
# Input box for the note, followed by the output components in the same
# order as the tuple returned by `ner`.
_note_input = gr.Textbox(label="Note text", value="")
_output_components = [
    gr.HighlightedText(label="NER", combine_adjacent=True),
    gr.JSON(label="Entity Counts"),
    gr.Label(label="Rating"),
    gr.Plot(label="Bar"),
]

# The example texts offered in the UI are the keys of EXAMPLE_MAP.
interface = gr.Interface(
    fn=ner,
    inputs=_note_input,
    outputs=_output_components,
    examples=list(EXAMPLE_MAP),
    allow_flagging="never",
)

# Start serving the app.
interface.launch()