Detsutut commited on
Commit
990c0de
1 Parent(s): 6fe70cc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -0
app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
+ from scripts.evaluation import input_classification
4
+ from scripts.explainer import CustomExplainer
5
+ import numpy as np
6
+
7
+ checkpoint = "Detsutut/medbit-assertion-negation"
8
+ model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=3, local_files_only=True)
9
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint, local_files_only=True)
10
+ cls_explainer = CustomExplainer(model, tokenizer)
11
+ sentence = "Il paziente non mostra alcun segno di [entità]."
12
+
13
+
14
+ def compute(text):
15
+ output = input_classification(model=model, tokenizer=tokenizer, x=text, all_classes=True)
16
+ exp = cls_explainer(text)
17
+ entities = []
18
+ start = 0
19
+ words_and_exp = cls_explainer.merge_attributions(exp)
20
+ low_threshold = np.percentile([float(abs(w)) for _, w in words_and_exp], 25)
21
+ high_threshold = np.percentile([float(abs(w)) for _, w in words_and_exp], 75)
22
+ for i, entity in enumerate(words_and_exp):
23
+ if entity[1] < 0 and entity[1] < low_threshold:
24
+ polarity = "-"
25
+ entities.append({"entity": polarity, "start": start, "end": start + len(entity[0])})
26
+ elif entity[1] > 0 and entity[1] > high_threshold:
27
+ polarity = "+"
28
+ entities.append({"entity": polarity, "start": start, "end": start + len(entity[0])})
29
+ start = start + len(entity[0]) + 1
30
+ return output, gr.HighlightedText(label="Explanation", visible=True, color_map={"+": "green", "-": "red"},
31
+ value={"text": " ".join([e[0] for e in words_and_exp]), "entities": entities},
32
+ combine_adjacent=True, adjacent_separator=" ")
33
+
34
+
35
+ with gr.Blocks(title="Inference GUI") as gui:
36
+ text = gr.Textbox(label="Input Text", value=sentence)
37
+ explanation = gr.HighlightedText(label="Explanation", visible=False)
38
+ output = gr.Label(label="Predicted Label", num_top_classes=3)
39
+ compute_btn = gr.Button("Predict")
40
+ compute_btn.click(fn=compute, inputs=text, outputs=[output, explanation], api_name="compute")
41
+
42
+ gui.launch()