tst / styleformer /adequacy.py
CAGmllab's picture
Upload 7 files
30c6353
raw
history blame contribute delete
No virus
1.61 kB
class Adequacy:
    """Score how well paraphrases preserve the meaning of an input phrase.

    Each ``(input_phrase, para_phrase)`` pair is run through a
    sequence-classification model (``prithivida/parrot_adequacy_model`` by
    default). The softmax probability of label index 1 is used as the
    pair's adequacy score.
    """

    # Token budget for the (input_phrase, para_phrase) pair; longer pairs are truncated.
    _MAX_LENGTH = 128

    def __init__(self, model_tag='prithivida/parrot_adequacy_model'):
        """Load the classifier and its tokenizer from ``model_tag``.

        Args:
            model_tag: Model identifier or path accepted by transformers'
                ``from_pretrained``.
        """
        # Imported lazily so importing this module does not require transformers.
        from transformers import AutoModelForSequenceClassification, AutoTokenizer
        self.adequacy_model = AutoModelForSequenceClassification.from_pretrained(model_tag)
        self.tokenizer = AutoTokenizer.from_pretrained(model_tag)

    def _adequacy_scores(self, input_phrase, para_phrases, device):
        """Return ``[(para_phrase, score), ...]`` in input order.

        ``score`` is the softmax probability of label index 1 for the pair
        ``(input_phrase, para_phrase)``.
        """
        import torch

        # Move the model once, not once per paraphrase.
        self.adequacy_model = self.adequacy_model.to(device)
        results = []
        # Pure inference: skip building the autograd graph.
        with torch.no_grad():
            for para_phrase in para_phrases:
                x = self.tokenizer(input_phrase, para_phrase, return_tensors='pt',
                                   max_length=self._MAX_LENGTH, truncation=True)
                # Inputs must live on the same device as the model.
                x = x.to(device)
                logits = self.adequacy_model(**x).logits
                probs = logits.softmax(dim=1)
                results.append((para_phrase, probs[:, 1].item()))
        return results

    def filter(self, input_phrase, para_phrases, adequacy_threshold, device="cpu"):
        """Keep the paraphrases whose adequacy score reaches the threshold.

        Args:
            input_phrase: The original phrase.
            para_phrases: Iterable of candidate paraphrases.
            adequacy_threshold: Minimum score (inclusive) for a paraphrase to be kept.
            device: Torch device the model and inputs are moved to.

        Returns:
            List of kept paraphrases, in input order.
        """
        return [
            para_phrase
            for para_phrase, adequacy_score in self._adequacy_scores(input_phrase, para_phrases, device)
            if adequacy_score >= adequacy_threshold
        ]

    def score(self, input_phrase, para_phrases, adequacy_threshold, device="cpu"):
        """Map each paraphrase that reaches the threshold to its adequacy score.

        Args:
            input_phrase: The original phrase.
            para_phrases: Iterable of candidate paraphrases.
            adequacy_threshold: Minimum score (inclusive) for a paraphrase to be kept.
            device: Torch device the model and inputs are moved to.

        Returns:
            Dict ``{para_phrase: score}`` in input order. A duplicate
            paraphrase keeps the score of its last passing occurrence.
        """
        return {
            para_phrase: adequacy_score
            for para_phrase, adequacy_score in self._adequacy_scores(input_phrase, para_phrases, device)
            if adequacy_score >= adequacy_threshold
        }