class Adequacy():

  def __init__(self, model_tag='prithivida/parrot_adequacy_model'):
    # Lazy import so the module stays cheap to import when adequacy
    # filtering is not needed.
    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    # Sequence-pair classifier: label 1 means the paraphrase preserves
    # the meaning of the input phrase (is "adequate").
    self.adequacy_model = AutoModelForSequenceClassification.from_pretrained(model_tag)
    self.tokenizer = AutoTokenizer.from_pretrained(model_tag)

  def filter(self, input_phrase, para_phrases, adequacy_threshold, device="cpu"):
    """Return only the paraphrases whose adequacy score meets the threshold."""
    top_adequacy_phrases = []
    self.adequacy_model = self.adequacy_model.to(device)
    for para_phrase in para_phrases:
      # Encode the (input, paraphrase) pair and move it to the same device
      # as the model.
      x = self.tokenizer(input_phrase, para_phrase, return_tensors='pt', max_length=128, truncation=True)
      x = x.to(device)
      logits = self.adequacy_model(**x).logits
      probs = logits.softmax(dim=1)
      # Probability of label 1 ("adequate paraphrase").
      adequacy_score = probs[:, 1].item()
      if adequacy_score >= adequacy_threshold:
        top_adequacy_phrases.append(para_phrase)
    return top_adequacy_phrases

  def score(self, input_phrase, para_phrases, adequacy_threshold, device="cpu"):
    """Return a dict mapping each passing paraphrase to its adequacy score."""
    adequacy_scores = {}
    self.adequacy_model = self.adequacy_model.to(device)
    for para_phrase in para_phrases:
      x = self.tokenizer(input_phrase, para_phrase, return_tensors='pt', max_length=128, truncation=True)
      x = x.to(device)
      logits = self.adequacy_model(**x).logits
      probs = logits.softmax(dim=1)
      adequacy_score = probs[:, 1].item()
      if adequacy_score >= adequacy_threshold:
        adequacy_scores[para_phrase] = adequacy_score
    return adequacy_scores
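
if __name__ == "__main__":
  # Minimal usage sketch. The phrases and the 0.90 threshold below are
  # illustrative assumptions, not values taken from the library; the model
  # tag defaults to 'prithivida/parrot_adequacy_model' as defined above.
  adequacy = Adequacy()

  input_phrase = "Can you recommend some upscale restaurants in New York?"
  para_phrases = [
      "Can you suggest some fancy restaurants in New York?",
      "What is the weather like in New York?",
  ]

  # Keep only paraphrases judged to preserve the input's meaning.
  print(adequacy.filter(input_phrase, para_phrases, adequacy_threshold=0.90))

  # Or inspect the adequacy score of each paraphrase that passes the threshold.
  print(adequacy.score(input_phrase, para_phrases, adequacy_threshold=0.90))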