File size: 2,812 Bytes
3fb5cbe
 
c865559
3fb5cbe
 
 
 
 
c865559
3fb5cbe
 
 
 
 
 
 
 
 
 
 
0383b1e
 
3fb5cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from torch.nn.functional import softmax
from transformers import MT5ForConditionalGeneration, MT5Tokenizer
import streamlit as st

def process_nli(premise: str, hypothesis: str):
    """Build the mT5 XNLI task input for a premise/hypothesis pair.

    The text is prefixed with the ``xnli:`` task tag, for example
    ``process_nli("a", "b") == "xnli: premise: a hypothesis: b"``.
    """
    return "xnli: premise: " + premise + " hypothesis: " + hypothesis

@st.cache(allow_output_mutation=True)
def setModel(model_name):
    """Load the mT5 model and tokenizer for ``model_name``, cached by Streamlit.

    Streamlit reruns reuse the cached pair for the same ``model_name``
    instead of loading it again.

    Returns:
        ``(model, tokenizer)``. The model has been put into eval mode.
    """
    tok = MT5Tokenizer.from_pretrained(model_name)
    mt5 = MT5ForConditionalGeneration.from_pretrained(model_name)
    # Inference only: switch off training-mode layers such as dropout.
    mt5.eval()
    return mt5, tok

def runModel(model_name, sequence_to_classify, candidate_labels, hypothesis_template):
    """Zero-shot classify ``sequence_to_classify`` against ``candidate_labels``.

    Each label is inserted into ``hypothesis_template`` with ``str.format``
    to form a hypothesis. Each (text, hypothesis) pair is scored by the mT5
    XNLI model returned from ``setModel``. The entailment logits are then
    softmaxed across the labels, similar to
    ``transformers.ZeroShotClassificationPipeline``.

    Args:
        model_name: Model name/path forwarded to ``setModel``.
        sequence_to_classify: Premise text to classify.
        candidate_labels: Iterable of label strings to score.
        hypothesis_template: Format string with one ``{}`` placeholder for the label.

    Returns:
        Dict mapping each label to its probability, ordered from most to
        least likely. Returns an empty dict when no labels are given.

    Raises:
        AssertionError: if the generated outputs do not have the expected
            shape (one label token), or if the three label tokens are not
            the top-3 scoring tokens.
    """
    # Label tokens the model is expected to generate for each NLI class.
    ENTAILS_LABEL = "▁0"
    NEUTRAL_LABEL = "▁1"
    CONTRADICTS_LABEL = "▁2"

    # Materialise once: the labels are iterated for the inputs and again for
    # the result, so a one-shot iterable would otherwise come back empty.
    candidate_labels = list(candidate_labels)
    if not candidate_labels:
        # Nothing to score; also avoids encoding an empty batch.
        return {}

    model, tokenizer = setModel(model_name)

    label_inds = tokenizer.convert_tokens_to_ids([ENTAILS_LABEL, NEUTRAL_LABEL, CONTRADICTS_LABEL])

    # One "xnli: premise: ... hypothesis: ..." input per candidate label.
    seqs = [
        process_nli(premise=sequence_to_classify, hypothesis=hypothesis_template.format(label))
        for label in candidate_labels
    ]

    inputs = tokenizer.batch_encode_plus(seqs, return_tensors="pt", padding=True)
    # Greedy decoding; keep per-step scores so the label logits can be read directly.
    out = model.generate(**inputs, output_scores=True, return_dict_in_generate=True, num_beams=1)

    # Sanity check: each output is start token + one label token + end token.
    for seq in out.sequences:
        assert len(seq) == 3, f"expected 3 generated tokens, got {len(seq)}"

    # Scores for the first generated step (the label token), one row per input.
    # Treated like the logits of a `*ForSequenceClassification` model.
    scores = out.scores[0]

    # Sanity check: the three label tokens should be each row's top-3 tokens.
    # topk avoids sorting the whole vocabulary just to read three entries.
    expected = set(label_inds)
    for sequence_scores in scores:
        top = set(sequence_scores.topk(3).indices.tolist())
        assert top == expected, f"top tokens {sorted(top)} differ from label tokens {sorted(expected)}"

    # Keep only the label columns, in the order entail, neutral, contradict.
    label_scores = scores[:, label_inds]

    # Softmax the entailment logits across labels (dim=0) to get a
    # zero-shot style distribution over the candidate labels.
    entailment_ind = 0
    entail_probas = softmax(label_scores[:, entailment_ind], dim=0)

    probas = dict(zip(candidate_labels, entail_probas.tolist()))
    return dict(sorted(probas.items(), key=lambda item: item[1], reverse=True))