from torch.nn.functional import softmax
from transformers import MT5ForConditionalGeneration, MT5Tokenizer
import streamlit as st


def process_nli(premise: str, hypothesis: str) -> str:
    """Format a premise/hypothesis pair in the mT5 XNLI input format.

    Example: ``process_nli("A", "B") == "xnli: premise: A hypothesis: B"``.
    """
    return f"xnli: premise: {premise} hypothesis: {hypothesis}"


def _cache_model_loader(func):
    """Cache *func*'s return value across Streamlit reruns.

    ``st.cache`` is deprecated in favour of ``st.cache_resource`` (the
    replacement for ``st.cache(allow_output_mutation=True)``). Use the newer
    API when this Streamlit version provides it; otherwise fall back to the
    legacy call so older installs keep working.
    """
    if hasattr(st, "cache_resource"):
        return st.cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


@_cache_model_loader
def setModel(model_name: str):
    """Load (and cache) an mT5 model and its tokenizer.

    Args:
        model_name: checkpoint name or path passed to ``from_pretrained``.

    Returns:
        ``(model, tokenizer)``; the model is switched to eval mode.
    """
    tokenizer = MT5Tokenizer.from_pretrained(model_name)
    model = MT5ForConditionalGeneration.from_pretrained(model_name)
    model.eval()
    return model, tokenizer


def runModel(model_name: str, sequence_to_classify: str, candidate_labels, hypothesis_template: str) -> dict:
    """Zero-shot classify *sequence_to_classify* against *candidate_labels*.

    Each label is substituted into *hypothesis_template* (via ``str.format``)
    to build a hypothesis. Every (sequence, hypothesis) pair is scored by the
    mT5 XNLI model, and the entailment logits are softmax-normalised across
    labels, in the style of ``ZeroShotClassificationPipeline``.

    Args:
        model_name: checkpoint name or path, forwarded to ``setModel``.
        sequence_to_classify: the premise text.
        candidate_labels: iterable of label strings. Duplicates are dropped
            (first occurrence kept) so each label appears once in the
            softmax and the returned probabilities sum to 1.
        hypothesis_template: format string with one positional ``{}`` slot.

    Returns:
        dict mapping label -> probability, ordered by descending probability.
        An empty dict when there are no candidate labels.

    Raises:
        RuntimeError: if the generated sequences are not exactly one label
            token long, or the three label tokens are not the top-3 scoring
            vocabulary entries (e.g. wrong checkpoint or tokenizer).
    """
    # Tokens the XNLI-tuned model emits for entailment / neutral / contradiction.
    # NOTE(review): confirm these exist in the loaded tokenizer's vocabulary;
    # the top-3 check below fails loudly if they do not.
    entails_label = "▁0"
    neutral_label = "▁1"
    contradicts_label = "▁2"

    # Deduplicate while preserving order: duplicates would otherwise share the
    # softmax mass but collapse into one key in the returned dict.
    labels = list(dict.fromkeys(candidate_labels))
    if not labels:
        return {}

    model, tokenizer = setModel(model_name)
    label_inds = tokenizer.convert_tokens_to_ids([entails_label, neutral_label, contradicts_label])

    # One premise/hypothesis input per candidate label, in mT5 XNLI format.
    seqs = [
        process_nli(premise=sequence_to_classify, hypothesis=hypothesis_template.format(label))
        for label in labels
    ]
    inputs = tokenizer.batch_encode_plus(seqs, return_tensors="pt", padding=True)
    out = model.generate(**inputs, output_scores=True, return_dict_in_generate=True, num_beams=1)

    # Expected length: decoder start token + one label token + end token = 3.
    # Explicit raises (not ``assert``) so the checks survive ``python -O``.
    if out.sequences.shape[1] != 3:
        raise RuntimeError(
            f"expected generated sequences of length 3, got {out.sequences.shape[1]}"
        )

    # Scores for the single generated token of interest; each row spans the
    # model's vocabulary. Treat them like the logits of a
    # ``*ForSequenceClassification`` model.
    scores = out.scores[0]

    # Sanity check: for every input the three label tokens must be the top-3
    # scoring vocabulary entries, otherwise slicing them out is meaningless.
    expected = set(label_inds)
    for row in scores:
        if set(row.argsort()[-3:].tolist()) != expected:
            raise RuntimeError("label tokens are not the top-3 scoring tokens for an input")

    # Keep only the task-label columns (order: entails, neutral, contradicts).
    scores = scores[:, label_inds]
    entailment_ind = 0

    # Softmax over labels (dim=0) of the entailment logit gives the
    # zero-shot distribution across candidate labels.
    entail_probas = softmax(scores[:, entailment_ind], dim=0)

    ranked = sorted(zip(labels, entail_probas.tolist()), key=lambda kv: kv[1], reverse=True)
    return dict(ranked)