andrewgleave commited on
Commit
023a91a
1 Parent(s): ecca0d1
Files changed (2) hide show
  1. app.py +24 -39
  2. requirements.txt +3 -3
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import json
2
- from collections import defaultdict
3
 
4
  import matplotlib.pyplot as plt
5
  import gradio as gr
@@ -11,22 +11,14 @@ model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-a
11
 
12
  plt.switch_backend("Agg")
13
 
14
- EXAMPLE_MAP = {}
15
  with open("examples.json", "r") as f:
16
- example_json = json.load(f)
17
- EXAMPLE_MAP = {x["text"]: x["label"] for x in example_json}
18
 
19
  pipe = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
20
 
21
 
22
- def group_by_entity(raw):
23
- out = defaultdict(int)
24
- for ent in raw:
25
- out[ent["entity_group"]] += 1
26
- # out["total"] = sum(out.values())
27
- return out
28
-
29
-
30
  def plot_to_figure(grouped):
31
  fig = plt.figure()
32
  plt.bar(x=list(grouped.keys()), height=list(grouped.values()))
@@ -36,7 +28,7 @@ def plot_to_figure(grouped):
36
  return fig
37
 
38
 
39
- def ner(text):
40
  raw = pipe(text)
41
  ner_content = {
42
  "text": text,
@@ -51,30 +43,23 @@ def ner(text):
51
  for x in raw
52
  ],
53
  }
54
- grouped = group_by_entity(raw)
55
- figure = plot_to_figure(grouped)
56
- label = EXAMPLE_MAP.get(text, "Unknown")
57
-
58
- meta = {
59
- "entity_counts": grouped,
60
- "entities": len(set(grouped.keys())),
61
- "counts": sum(grouped.values()),
62
- }
63
 
64
- return (ner_content, meta, label, figure)
65
-
66
-
67
- interface = gr.Interface(
68
- ner,
69
- inputs=gr.Textbox(label="Note text", value=""),
70
- outputs=[
71
- gr.HighlightedText(label="NER", combine_adjacent=True),
72
- gr.JSON(label="Entity Counts"),
73
- gr.Label(label="Rating"),
74
- gr.Plot(label="Bar"),
75
- ],
76
- examples=list(EXAMPLE_MAP.keys()),
77
- allow_flagging="never",
78
- )
79
-
80
- interface.launch()
 
 
 
1
  import json
2
+ from collections import defaultdict, Counter
3
 
4
  import matplotlib.pyplot as plt
5
  import gradio as gr
 
11
 
12
  plt.switch_backend("Agg")
13
 
14
+ examples = {}
15
  with open("examples.json", "r") as f:
16
+ content = json.load(f)
17
+ examples = {x["text"]: x["label"] for x in content}
18
 
19
  pipe = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
20
 
21
 
 
 
 
 
 
 
 
 
22
  def plot_to_figure(grouped):
23
  fig = plt.figure()
24
  plt.bar(x=list(grouped.keys()), height=list(grouped.values()))
 
28
  return fig
29
 
30
 
31
+ def run_ner(text):
32
  raw = pipe(text)
33
  ner_content = {
34
  "text": text,
 
43
  for x in raw
44
  ],
45
  }
 
 
 
 
 
 
 
 
 
46
 
47
+ grouped = Counter((x["entity_group"] for x in raw))
48
+ rows = [[k, v] for k, v in grouped.items()]
49
+ figure = plot_to_figure(grouped)
50
+ return ner_content, rows, figure
51
+
52
+
53
+ with gr.Blocks() as demo:
54
+ note = gr.Textbox(label="Note text")
55
+ with gr.Accordion("Examples", open=False):
56
+ examples = gr.Examples(examples=list(examples.keys()), inputs=note)
57
+ with gr.Tab("NER"):
58
+ highlight = gr.HighlightedText(label="NER", combine_adjacent=True)
59
+ with gr.Tab("Bar"):
60
+ plot = gr.Plot(label="Bar")
61
+ with gr.Tab("Table"):
62
+ table = gr.Dataframe(headers=["Entity", "Count"])
63
+ note.submit(run_ner, [note], [highlight, table, plot])
64
+
65
+ demo.launch()
requirements.txt CHANGED
@@ -18,7 +18,7 @@ filelock==3.8.0
18
  fonttools==4.37.4
19
  frozenlist==1.3.1
20
  fsspec==2022.8.2
21
- gradio==3.4.1
22
  h11==0.12.0
23
  httpcore==0.15.0
24
  httpx==0.23.0
@@ -60,9 +60,9 @@ sniffio==1.3.0
60
  starlette==0.20.4
61
  tokenizers==0.12.1
62
  tomli==2.0.1
63
- torch==1.12.1
64
  tqdm==4.64.1
65
- transformers==4.22.2
66
  typing_extensions==4.4.0
67
  uc-micro-py==1.0.1
68
  urllib3==1.26.12
 
18
  fonttools==4.37.4
19
  frozenlist==1.3.1
20
  fsspec==2022.8.2
21
+ gradio==3.11.0
22
  h11==0.12.0
23
  httpcore==0.15.0
24
  httpx==0.23.0
 
60
  starlette==0.20.4
61
  tokenizers==0.12.1
62
  tomli==2.0.1
63
+ torch==1.13.0
64
  tqdm==4.64.1
65
+ transformers==4.24.0
66
  typing_extensions==4.4.0
67
  uc-micro-py==1.0.1
68
  urllib3==1.26.12