File size: 2,057 Bytes
990c0de
 
0e57b8a
 
990c0de
 
 
18e08b6
 
990c0de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from evaluation import input_classification
from explainer import CustomExplainer
import numpy as np

# Hugging Face checkpoint that provides both the classifier weights and the tokenizer.
checkpoint = "Detsutut/medbit-assertion-negation"
# Default example pre-filled in the input textbox ("The patient shows no sign of [entity].").
sentence = "Il paziente non mostra alcun segno di [entità]."

# Tokenizer and classifier come from the same checkpoint; the classifier is
# loaded with a 3-label head.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=3)
# The explainer needs both the model and its tokenizer, so it is built last.
cls_explainer = CustomExplainer(model, tokenizer)


def compute(text):
    """Classify ``text`` and build a word-level attribution highlight.

    ``input_classification`` produces the class scores shown in the
    ``gr.Label``. ``cls_explainer`` produces attributions that
    ``merge_attributions`` turns into items unpacked here as
    ``(word, score)``. A word is tagged:

    * ``"+"`` when its score is positive and its magnitude exceeds the 75th
      percentile of all magnitudes;
    * ``"-"`` when its score is negative and its magnitude exceeds the 25th
      percentile of all magnitudes.

    Args:
        text: Sentence typed in the input textbox.

    Returns:
        ``(output, highlighted)``. ``output`` is the classification result.
        ``highlighted`` is a visible ``gr.HighlightedText`` whose value is the
        space-joined words plus character-offset entities.
    """
    output = input_classification(model=model, tokenizer=tokenizer, x=text, all_classes=True)
    exp = cls_explainer(text)
    words_and_exp = cls_explainer.merge_attributions(exp)

    entities = []
    magnitudes = [float(abs(w)) for _, w in words_and_exp]
    # np.percentile raises on an empty sequence, so thresholds are only
    # computed when there is at least one word to highlight.
    if magnitudes:
        low_threshold, high_threshold = np.percentile(magnitudes, [25, 75])
        start = 0
        for word, weight in words_and_exp:
            score = float(weight)
            end = start + len(word)
            # Both thresholds are magnitudes (>= 0), so a negative score must
            # be compared against the negated threshold. A plain
            # ``score < low_threshold`` would be true for every negative score.
            if score < -low_threshold:
                entities.append({"entity": "-", "start": start, "end": end})
            elif score > high_threshold:
                entities.append({"entity": "+", "start": start, "end": end})
            # Skip past the word plus the single space used by " ".join below,
            # so the offsets line up with the displayed text.
            start = end + 1

    return output, gr.HighlightedText(label="Explanation", visible=True, color_map={"+": "green", "-": "red"},
                                      value={"text": " ".join(word for word, _ in words_and_exp), "entities": entities},
                                      combine_adjacent=True, adjacent_separator=" ")


with gr.Blocks(title="Inference GUI") as gui:
    text = gr.Textbox(label="Input Text", value=sentence)
    # Starts hidden; compute() returns a HighlightedText with visible=True,
    # which replaces this placeholder after the first prediction.
    explanation = gr.HighlightedText(label="Explanation", visible=False)
    output = gr.Label(label="Predicted Label", num_top_classes=3)
    compute_btn = gr.Button("Predict")
    # compute() returns (classification_output, highlighted_text) in the same
    # order as the outputs list below.
    compute_btn.click(fn=compute, inputs=text, outputs=[output, explanation], api_name="compute")

# Only start the (blocking) server when run as a script, so importing this
# module does not launch the app as a side effect.
if __name__ == "__main__":
    gui.launch()