intent-analysis / app.py
youj2005's picture
Switched qa model
8cb81df
raw
history blame contribute delete
No virus
4.42 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import pipeline
import torch
# Pick one device for every model so inputs and weights always agree:
# CUDA if present, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Textual-entailment (TE) model used to score the hypotheses built in predict().
te_tokenizer = AutoTokenizer.from_pretrained('MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli')
te_model = AutoModelForSequenceClassification.from_pretrained('MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli').to(device)

# Extractive QA pipeline used to find the object described in the statement.
qa_pipeline = pipeline("question-answering", model='distilbert/distilbert-base-cased-distilled-squad', device=device)

# Seq2seq model used to generate the opposite of the intent. It is placed
# explicitly on `device` instead of using device_map="auto": "auto" could pick
# a different device than the one the inputs are sent to, and it requires the
# `accelerate` package.
qa_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
qa_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base").to(device)
# Upper bound on tokens generated for the "opposite" of the intent. The old
# generate(max_length=2) counted the decoder start token, so only ONE sub-word
# token was produced and multi-token antonyms were cut to their first piece.
_MAX_OPPOSITE_TOKENS = 8

# Output label names, in the index order this code uses for TE scores.
_LABELS = ("agree", "neutral", "disagree")


def _opposite_of(intent):
    """Ask the flan-t5 model for the opposite of ``intent``; return the decoded text."""
    prompt = "What is the opposite of " + intent + "?"
    # Send inputs to wherever the model actually lives, not an assumed device.
    input_ids = qa_tokenizer(prompt, return_tensors="pt").input_ids.to(qa_model.device)
    generated = qa_model.generate(input_ids, max_new_tokens=_MAX_OPPOSITE_TOKENS)
    return qa_tokenizer.decode(generated[0], skip_special_tokens=True)


def _described_object(context):
    """Extract (at most 2 tokens of) the object/thing described in ``context``."""
    question = "What object/thing is being described in the entire sentence?"
    return qa_pipeline(question=question, context=context, max_answer_len=2)['answer']


def _te_logits(context, hypothesis):
    """Return the TE model's raw logits for (premise=context, hypothesis) as a 1-D tensor."""
    # truncation=True keeps over-long statements from exceeding the model limit.
    input_ids = te_tokenizer.encode(
        context, hypothesis, return_tensors='pt', truncation=True
    ).to(te_model.device)
    return te_model(input_ids)[0][0]


def predict(context, intent, multi_class):
    """Score how well ``intent`` describes the object mentioned in ``context``.

    Three hypotheses are scored against ``context`` with the TE model:
    "The <obj> is <intent>", "The <obj> is <opposite>" and
    "The <obj> is neither <intent> nor <opposite>". They are then combined
    into agree/neutral/disagree scores.

    The TE logits are treated as [entailment, neutral, contradiction] here,
    which is how the indices below (and the returned labels) are used.
    NOTE(review): confirm against te_model.config.id2label.

    Args:
        context: The statement to analyse.
        intent: The class/property to test (e.g. "long").
        multi_class: If truthy, squash each score independently with a
            sigmoid; otherwise softmax them so they sum to 1.

    Returns:
        A pair of dicts mapping "agree"/"neutral"/"disagree" to floats:
        the post-processed scores, and plain TE scores for the
        "The <obj> is <intent>" hypothesis alone.

    Raises:
        gr.Error: If ``context`` or ``intent`` is empty or blank.
    """
    if not (context and context.strip() and intent and intent.strip()):
        raise gr.Error("Please provide both a statement and a class.")

    print(context, intent)
    opposite_output = _opposite_of(intent)
    object_output = _described_object(context)
    batch = [
        'The ' + object_output + ' is ' + intent,
        'The ' + object_output + ' is ' + opposite_output,
        'The ' + object_output + ' is neither ' + intent + ' nor ' + opposite_output,
    ]
    print(batch)

    # Pure inference: no autograd graph needed.
    with torch.no_grad():
        intent_logits, opposite_logits, neither_logits = [
            _te_logits(context, hypothesis) for hypothesis in batch
        ]

    # exp() of the raw logits: positive but NOT normalised.
    positive = torch.exp(intent_logits)
    # Flipping [ent, neu, con] -> [con, neu, ent], so entailing the opposite
    # counts towards "disagree" and contradicting it towards "agree".
    negative = torch.exp(opposite_logits).flip(dims=[0])
    # Two-way distribution [entailment, contradiction] for the
    # "neither ... nor ..." hypothesis.
    perfect_prob = neither_logits[[0, 2]].softmax(dim=0)
    print([positive, negative, perfect_prob])
    print(perfect_prob)

    # Average the positive view with the flipped negative view.
    mean = (positive + negative) / 2
    print(mean)
    # Entailing "neither" boosts neutral. Contradicting "neither" is used to
    # weight the agree/disagree scores.
    aggregated = torch.stack([
        mean[0] * perfect_prob[1],
        mean[1] + perfect_prob[0],
        mean[2] * perfect_prob[1],
    ])

    if multi_class:
        # Multiple true classes: independent per-label scores.
        aggregated = torch.sigmoid(aggregated)
        normal = torch.sigmoid(intent_logits)
    else:
        # Only one true class: a distribution over the three labels.
        aggregated = aggregated.softmax(dim=0)
        normal = intent_logits.softmax(dim=0)

    # Plain floats serialise cleanly for gr.Label, regardless of tensor device.
    return (
        {label: score.item() for label, score in zip(_LABELS, aggregated)},
        {label: score.item() for label, score in zip(_LABELS, normal)},
    )
# Demo inputs (statement, class); computed once at startup because
# cache_examples=True below.
examples = [["These are so warm and comfortable. I’m 5’7”, 140 lbs, size 6-8 and Medium is a great fit. They wash and dry nicely too. The jogger style is the only style I can wear in this brand - the others are way too long so I had to return.", "long"], ["I feel strongly about politics in the US", "long"], ["The pants are long", "long"], ["The pants are slightly long", "long"]]

gradio_app = gr.Interface(
    predict,
    examples=examples,
    inputs=[gr.Text(label="Statement"), gr.Text(label="Class"), gr.Checkbox(label="Allow multiple true classes")],
    outputs=[gr.Label(num_top_classes=3, label="With Postprocessing"), gr.Label(num_top_classes=3, label="Without Postprocessing")],
    title="Intent Analysis",
    description="This model predicts whether or not the **_class_** describes the **_object described in the sentence_**. <br /> The two outputs show what TE would predict with and without the postprocessing. An example edge case for normal TE is shown below. <br /> **_It is recommended that you clone the repository to speed up processing time_**. <br /> Additionally, note the difference in the strength of the probability between the last two examples: the former represents a strong opinion and the latter a weaker one.",
    cache_examples=True,
)

# Only start the server when run as a script, so importing this module
# does not launch the app as a side effect.
if __name__ == "__main__":
    gradio_app.launch()