|
from torch.nn.functional import softmax |
|
from transformers import MT5ForConditionalGeneration, MT5Tokenizer |
|
import streamlit as st |
|
|
|
def process_nli(premise: str, hypothesis: str):
    """Format one premise/hypothesis pair as an XNLI prompt with task prefix.

    The result is the literal text ``xnli: premise: <premise> hypothesis:
    <hypothesis>``; both arguments are inserted verbatim.
    """
    premise_part = 'xnli: premise: ' + premise
    hypothesis_part = ' hypothesis: ' + hypothesis
    return premise_part + hypothesis_part
|
|
|
@st.cache(allow_output_mutation=True)
def setModel(model_name):
    """Load the mT5 model and its tokenizer for ``model_name``.

    Both objects are created with ``from_pretrained(model_name)`` and the
    model is switched to evaluation mode before being returned.  The function
    is wrapped in Streamlit's ``st.cache`` (with ``allow_output_mutation=True``)
    so that, presumably, repeated calls with the same name reuse the loaded
    objects instead of loading them again.

    Returns:
        A ``(model, tokenizer)`` tuple -- note the model comes first.
    """
    mt5_tokenizer = MT5Tokenizer.from_pretrained(model_name)
    # ``eval()`` returns the module itself, so it can be chained onto the load.
    mt5_model = MT5ForConditionalGeneration.from_pretrained(model_name).eval()
    return mt5_model, mt5_tokenizer
|
|
|
def runModel(model_name, sequence_to_classify, candidate_labels, hypothesis_template):
    """Rank candidate labels for a text using an mT5 XNLI model (zero-shot).

    Every candidate label is substituted into ``hypothesis_template`` and
    paired with ``sequence_to_classify`` as the premise.  The model generates
    one class token per pair; the entailment logit of each pair is then
    soft-maxed across all candidate labels.

    Args:
        model_name: Name handed to ``setModel`` to obtain model and tokenizer.
        sequence_to_classify: The premise text.
        candidate_labels: Iterable of label strings to rank.
        hypothesis_template: ``str.format`` template with one positional
            placeholder that receives the label.

    Returns:
        A dict mapping each candidate label to its probability, ordered from
        most to least probable.  Empty when ``candidate_labels`` is empty.

    Raises:
        AssertionError: If a generated sequence does not have length 3, or if
            the three highest-scoring tokens are not exactly the three class
            tokens.
    """
    # Materialize once: the labels are iterated twice (to build the pairs and
    # to build the result), which would silently drop everything for a
    # one-shot iterable.
    labels = list(candidate_labels)
    if not labels:
        # Nothing to rank; avoid loading the model and encoding an empty batch.
        return {}

    # Class tokens: the SentencePiece word-boundary marker (U+2581) followed
    # by the class digit, written as escapes so the marker cannot be mangled
    # by editors or encoding conversions.
    # NOTE(review): assumes the checkpoint emits these tokens for
    # entail / neutral / contradict -- confirm against the model card.
    ENTAILS_LABEL = "\u25810"
    NEUTRAL_LABEL = "\u25811"
    CONTRADICTS_LABEL = "\u25812"

    model, tokenizer = setModel(model_name)

    label_inds = tokenizer.convert_tokens_to_ids(
        [ENTAILS_LABEL, NEUTRAL_LABEL, CONTRADICTS_LABEL]
    )

    # One (premise, hypothesis) pair per candidate label, formatted for the
    # XNLI task.
    pairs = [(sequence_to_classify, hypothesis_template.format(label)) for label in labels]
    seqs = [process_nli(premise=premise, hypothesis=hypothesis) for premise, hypothesis in pairs]

    inputs = tokenizer.batch_encode_plus(seqs, return_tensors="pt", padding=True)
    out = model.generate(**inputs, output_scores=True, return_dict_in_generate=True, num_beams=1)

    # Sanity check: every generated sequence must have exactly 3 tokens
    # (presumably start token + class token + end token).
    for idx, seq in enumerate(out.sequences):
        assert len(seq) == 3, (
            f"generated sequence {idx} not of expected length 3; "
            f"actual length: {len(seq)}"
        )

    # Scores of the first generated step, one row per input pair.
    step_scores = out.scores[0]

    # Sanity check: the three class tokens must be the three top-scoring
    # tokens for every pair.
    expected_ids = set(label_inds)
    for idx, sequence_scores in enumerate(step_scores):
        top_ids = sequence_scores.argsort()[-3:].tolist()
        assert set(top_ids) == expected_ids, (
            f"top scoring tokens for sequence {idx} are not the class tokens; "
            f"expected: {label_inds}, got: {top_ids}"
        )

    # Keep only the entailment logit of each pair and normalize it across the
    # candidate labels (dim=0), so the probabilities sum to 1 over the labels.
    entail_scores = step_scores[:, label_inds[0]]
    entail_probas = softmax(entail_scores, dim=0)

    by_label = dict(zip(labels, entail_probas.tolist()))
    return dict(sorted(by_label.items(), key=lambda item: item[1], reverse=True))