|
import gradio as gr |
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
from evaluation import input_classification |
|
from explainer import CustomExplainer |
|
import numpy as np |
|
|
|
# Example sentence pre-filled in the GUI input box
# (Italian: "The patient does not show any sign of [entity].").
sentence = "Il paziente non mostra alcun segno di [entità]."

# Hugging Face checkpoint of the fine-tuned sequence classifier.
checkpoint = "Detsutut/medbit-assertion-negation"

# local_files_only=True: no download is attempted, so the checkpoint files
# must already be available locally before the app starts.
_hf_load_kwargs = {"local_files_only": True}

# Classifier head with three output labels, plus its matching tokenizer.
model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint, num_labels=3, **_hf_load_kwargs
)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, **_hf_load_kwargs)

# Attribution explainer wrapping the same model/tokenizer pair.
cls_explainer = CustomExplainer(model, tokenizer)
|
|
|
|
|
def _highlight_entities(words_and_exp):
    """Turn word attributions into a space-joined text plus highlight spans.

    Args:
        words_and_exp: Re-iterable sequence of ``(word, attribution)`` pairs,
            e.g. the result of ``cls_explainer.merge_attributions``. Each
            attribution must be convertible with ``float()``.

    Returns:
        ``(text, entities)`` where ``text`` is the words joined by single
        spaces and ``entities`` is a list of ``{"entity", "start", "end"}``
        dicts with character offsets into ``text``. A word is tagged ``"+"``
        when its attribution is above the 75th percentile of all absolute
        attributions, and ``"-"`` when it is below the negation of that
        value. An empty input yields ``("", [])``.
    """
    pairs = list(words_and_exp)
    if not pairs:
        # np.percentile cannot reduce an empty array.
        return "", []

    scores = [float(score) for _, score in pairs]
    # One magnitude cut-off for both signs, so strongly positive and strongly
    # negative words are highlighted symmetrically.
    threshold = float(np.percentile(np.abs(scores), 75))

    entities = []
    start = 0
    for (word, _), score in zip(pairs, scores):
        end = start + len(word)
        if score > threshold:
            entities.append({"entity": "+", "start": start, "end": end})
        elif score < -threshold:
            entities.append({"entity": "-", "start": start, "end": end})
        start = end + 1  # skip the single space used when joining the words
    return " ".join(word for word, _ in pairs), entities


def compute(text):
    """Classify ``text`` and explain the prediction with word highlights.

    Args:
        text: Input sentence from the GUI textbox.

    Returns:
        ``(output, highlighted)``: ``output`` is the result of
        ``input_classification(..., all_classes=True)`` shown in the label
        component, and ``highlighted`` is a visible ``gr.HighlightedText``
        marking words with strongly positive ("+", green) or strongly
        negative ("-", red) attributions.
    """
    output = input_classification(model=model, tokenizer=tokenizer, x=text, all_classes=True)

    # NOTE(review): merge_attributions presumably folds sub-word token
    # attributions into whole words — confirm in explainer.CustomExplainer.
    words_and_exp = cls_explainer.merge_attributions(cls_explainer(text))
    joined_text, entities = _highlight_entities(words_and_exp)

    return output, gr.HighlightedText(
        label="Explanation",
        visible=True,
        color_map={"+": "green", "-": "red"},
        value={"text": joined_text, "entities": entities},
        combine_adjacent=True,
        adjacent_separator=" ",
    )
|
|
|
|
|
# Single-page inference UI: input box, explanation panel (hidden until the
# first prediction), top-3 label display and a button wired to `compute`.
gui = gr.Blocks(title="Inference GUI")
with gui:
    text = gr.Textbox(label="Input Text", value=sentence)
    # Starts hidden; `compute` returns a visible HighlightedText in its place.
    explanation = gr.HighlightedText(label="Explanation", visible=False)
    output = gr.Label(label="Predicted Label", num_top_classes=3)
    compute_btn = gr.Button("Predict")

    # Output order must match compute's (label, explanation) return tuple.
    compute_btn.click(
        fn=compute,
        inputs=text,
        outputs=[output, explanation],
        api_name="compute",
    )

gui.launch()