intent-analysis / app.py
youj2005's picture
Made improvements and changed base mnli model
36b3b29
raw
history blame
No virus
4.32 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
# Pick the fastest available backend: CUDA, then Apple's MPS, then the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Textual-entailment (NLI) model used to score each hypothesis against the statement.
TE_MODEL_NAME = 'MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli'
te_tokenizer = AutoTokenizer.from_pretrained(TE_MODEL_NAME)
te_model = AutoModelForSequenceClassification.from_pretrained(TE_MODEL_NAME).to(device)

# Instruction-tuned T5 used to produce the "opposite" of the class and the object described.
QA_MODEL_NAME = "google/flan-t5-large"
qa_tokenizer = T5Tokenizer.from_pretrained(QA_MODEL_NAME)
# device_map="auto" delegates weight placement to transformers/accelerate, so
# qa_model is not guaranteed to live on `device`; inputs should follow
# qa_model.device.
qa_model = T5ForConditionalGeneration.from_pretrained(QA_MODEL_NAME, device_map="auto")
def _ask_qa_model(prompt):
    """Return flan-t5's decoded answer to ``prompt``.

    ``max_length=2`` keeps the answer extremely short.
    NOTE(review): for an encoder-decoder model this length budget likely
    includes the decoder start token, so answers the tokenizer splits into
    several pieces may be truncated. Confirm this is intended.
    """
    # qa_model was loaded with device_map="auto", so it may not live on
    # `device`; send the inputs to wherever its weights actually are.
    input_ids = qa_tokenizer(prompt, return_tensors="pt").input_ids.to(qa_model.device)
    generated = qa_model.generate(input_ids, max_length=2)
    return qa_tokenizer.decode(generated[0], skip_special_tokens=True)


def _entailment_logits(premise, hypothesis):
    """Return the TE model's raw logits for one (premise, hypothesis) pair.

    The result is the first (and only) row of the model's logits, i.e. one
    score per label of ``te_model``.
    """
    input_ids = te_tokenizer.encode(premise, hypothesis, return_tensors='pt').to(te_model.device)
    return te_model(input_ids)[0][0]


def predict(context, intent, multi_class=False):
    """Score whether ``intent`` describes the object mentioned in ``context``.

    The QA model first produces the opposite of ``intent`` and the object the
    statement describes. Three hypotheses are then scored against ``context``
    with the TE model:

      * "The <object> is <intent>"                           (positive)
      * "The <object> is <opposite>"                         (negative)
      * "The <object> is neither <intent> nor <opposite>"    (neither)

    Args:
        context: The statement to analyse.
        intent: The class (e.g. "long") to test against the statement.
        multi_class: If truthy, scores are squashed independently with a
            sigmoid (several classes may be true at once); otherwise they are
            normalised with a softmax. Defaults to False so callers that omit
            it still work.

    Returns:
        A pair of dicts mapping "agree" / "neutral" / "disagree" to Python
        floats. The first holds the post-processed scores. The second holds
        plain TE on the positive hypothesis alone.
    """
    # Pure inference: do not build an autograd graph for the model calls or
    # the arithmetic below.
    with torch.no_grad():
        opposite_output = _ask_qa_model("What is the opposite of " + intent + "?")
        object_output = _ask_qa_model("What object is the following describing: " + context)

        prefix = 'The ' + object_output + ' is '
        positive_logits = _entailment_logits(context, prefix + intent)
        negative_logits = _entailment_logits(context, prefix + opposite_output)
        neither_logits = _entailment_logits(
            context, prefix + 'neither ' + intent + ' nor ' + opposite_output
        )

        # NOTE(review): label indices are taken to be
        # [entailment, neutral, contradiction]. That is the order the return
        # mapping below requires (index 0 -> "agree", index 2 -> "disagree").
        # An older comment here said [contradiction, neutral, entailment],
        # which likely predates the switch of base MNLI model. Confirm
        # against te_model.config.id2label.

        # Baseline "normal TE" output: the positive hypothesis alone.
        normal = positive_logits

        # exp(logits) are unnormalised scores; normalisation happens only at
        # the very end (softmax or sigmoid).
        positive_scores = torch.exp(positive_logits)
        # Under the ordering above, reversing the negative hypothesis lines up
        # its contradiction with the positive hypothesis' entailment (and vice
        # versa): contradicting the opposite counts toward agreeing.
        negative_scores = torch.exp(negative_logits).flip(dims=[0])
        # Two-way distribution [entailment, contradiction] for "neither".
        neither_prob = neither_logits[[0, 2]].softmax(dim=0)

        # Average the positive and (flipped) negative evidence.
        aggregated = (positive_scores + negative_scores) / 2
        # "Neither" being entailed supports neutral; "neither" being
        # contradicted supports both polar answers.
        aggregated[1] = aggregated[1] * neither_prob[0]
        aggregated[0] = aggregated[0] * neither_prob[1]
        aggregated[2] = aggregated[2] * neither_prob[1]
        aggregated = torch.sqrt(aggregated)

        if multi_class:
            # NOTE(review): aggregated is non-negative after sqrt, so every
            # sigmoid value here is >= 0.5.
            aggregated = torch.sigmoid(aggregated)
            normal = torch.sigmoid(normal)
        else:
            aggregated = aggregated.softmax(dim=0)
            normal = normal.softmax(dim=0)

    # Return plain Python floats for both outputs rather than tensor elements.
    aggregated = aggregated.tolist()
    normal = normal.tolist()
    return (
        {"agree": aggregated[0], "neutral": aggregated[1], "disagree": aggregated[2]},
        {"agree": normal[0], "neutral": normal[1], "disagree": normal[2]},
    )
# Each example row supplies (statement, class, allow multiple true classes).
# The checkbox value is given explicitly so every row matches the three inputs.
examples = [
    ["These are my absolute favorite cargos in my closet. I’m 5’7 and they’re actually long enough for me. I’m 165lbs and ordered an M & it fits nice and loose just how I wanted it. The adjustable waist band is awesome!", "long", False],
    ["I feel strongly about politics in the US", "long", False],
    ["The pants are long", "long", False],
    ["The pants are slightly long", "long", False],
]

# Web UI: two Label outputs show the scores with and without post-processing.
gradio_app = gr.Interface(
    predict,
    examples=examples,
    inputs=[gr.Text(label="Statement"), gr.Text(label="Class"), gr.Checkbox(label="Allow multiple true classes")],
    outputs=[gr.Label(num_top_classes=3, label="With Postprocessing"), gr.Label(num_top_classes=3, label="Without Postprocessing")],
    title="Intent Analysis",
    description="This model predicts whether or not the **_class_** describes the **_object described in the sentence_**. <br /> The two outputs shows what TE would predict with and without the postprocessing. An example edge case for normal TE is shown below. <br /> **_It is recommended that you clone the repository to speed up processing time_**. <br /> Additionally, note the difference between the strength of the probability when going between the last two examples, the former representing a strong opinion and the latter a weaker opinion",
    cache_examples=True,
)

# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    gradio_app.launch()