Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
3 |
+
from scripts.evaluation import input_classification
|
4 |
+
from scripts.explainer import CustomExplainer
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
checkpoint = "Detsutut/medbit-assertion-negation"
|
8 |
+
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=3, local_files_only=True)
|
9 |
+
tokenizer = AutoTokenizer.from_pretrained(checkpoint, local_files_only=True)
|
10 |
+
cls_explainer = CustomExplainer(model, tokenizer)
|
11 |
+
sentence = "Il paziente non mostra alcun segno di [entità]."
|
12 |
+
|
13 |
+
|
14 |
+
def compute(text):
|
15 |
+
output = input_classification(model=model, tokenizer=tokenizer, x=text, all_classes=True)
|
16 |
+
exp = cls_explainer(text)
|
17 |
+
entities = []
|
18 |
+
start = 0
|
19 |
+
words_and_exp = cls_explainer.merge_attributions(exp)
|
20 |
+
low_threshold = np.percentile([float(abs(w)) for _, w in words_and_exp], 25)
|
21 |
+
high_threshold = np.percentile([float(abs(w)) for _, w in words_and_exp], 75)
|
22 |
+
for i, entity in enumerate(words_and_exp):
|
23 |
+
if entity[1] < 0 and entity[1] < low_threshold:
|
24 |
+
polarity = "-"
|
25 |
+
entities.append({"entity": polarity, "start": start, "end": start + len(entity[0])})
|
26 |
+
elif entity[1] > 0 and entity[1] > high_threshold:
|
27 |
+
polarity = "+"
|
28 |
+
entities.append({"entity": polarity, "start": start, "end": start + len(entity[0])})
|
29 |
+
start = start + len(entity[0]) + 1
|
30 |
+
return output, gr.HighlightedText(label="Explanation", visible=True, color_map={"+": "green", "-": "red"},
|
31 |
+
value={"text": " ".join([e[0] for e in words_and_exp]), "entities": entities},
|
32 |
+
combine_adjacent=True, adjacent_separator=" ")
|
33 |
+
|
34 |
+
|
35 |
+
with gr.Blocks(title="Inference GUI") as gui:
|
36 |
+
text = gr.Textbox(label="Input Text", value=sentence)
|
37 |
+
explanation = gr.HighlightedText(label="Explanation", visible=False)
|
38 |
+
output = gr.Label(label="Predicted Label", num_top_classes=3)
|
39 |
+
compute_btn = gr.Button("Predict")
|
40 |
+
compute_btn.click(fn=compute, inputs=text, outputs=[output, explanation], api_name="compute")
|
41 |
+
|
42 |
+
gui.launch()
|