|
class Adequacy:
    """Score paraphrase candidates against an input phrase with a sequence classifier.

    The model and tokenizer are loaded from the same ``model_tag``. Each
    (input phrase, paraphrase) pair is classified, and the softmax
    probability at class index 1 is used as the adequacy score.
    """

    # Token budget for an encoded (input_phrase, para_phrase) pair; longer
    # pairs are truncated by the tokenizer.
    _MAX_LENGTH = 128

    def __init__(self, model_tag='prithivida/parrot_adequacy_model'):
        """Load the classifier and its tokenizer.

        Args:
            model_tag: Identifier passed to ``from_pretrained`` for both the
                model and the tokenizer.
        """
        # Imported lazily so that importing this module does not require
        # transformers until an instance is actually created.
        from transformers import AutoModelForSequenceClassification, AutoTokenizer

        self.adequacy_model = AutoModelForSequenceClassification.from_pretrained(model_tag)
        self.tokenizer = AutoTokenizer.from_pretrained(model_tag)

    def _adequacy_score(self, input_phrase, para_phrase, device):
        """Return the adequacy score of one paraphrase as a Python float.

        The caller must already have moved ``self.adequacy_model`` to
        ``device``; this helper only moves the encoded inputs.
        """
        encoded = self.tokenizer(
            input_phrase,
            para_phrase,
            return_tensors='pt',
            max_length=self._MAX_LENGTH,
            truncation=True,
        )
        # Inputs must live on the same device as the model, otherwise the
        # forward pass fails for any non-CPU device.
        encoded = encoded.to(device)
        logits = self.adequacy_model(**encoded).logits
        probs = logits.softmax(dim=1)
        # Column 1 is treated as the "adequate" class probability.
        return probs[:, 1].item()

    def filter(self, input_phrase, para_phrases, adequacy_threshold, device="cpu"):
        """Return the paraphrases whose adequacy score meets the threshold.

        Args:
            input_phrase: The original phrase.
            para_phrases: Iterable of candidate paraphrases.
            adequacy_threshold: Minimum score (inclusive) a candidate needs.
            device: Device the model and inputs are moved to.

        Returns:
            A list of the accepted paraphrases, in input order.
        """
        # Moving the model is loop-invariant, so do it once up front.
        self.adequacy_model = self.adequacy_model.to(device)
        return [
            para_phrase
            for para_phrase in para_phrases
            if self._adequacy_score(input_phrase, para_phrase, device) >= adequacy_threshold
        ]

    def score(self, input_phrase, para_phrases, adequacy_threshold, device="cpu"):
        """Return the adequacy scores of paraphrases that meet the threshold.

        Args:
            input_phrase: The original phrase.
            para_phrases: Iterable of candidate paraphrases.
            adequacy_threshold: Minimum score (inclusive) a candidate needs.
            device: Device the model and inputs are moved to.

        Returns:
            A dict mapping each accepted paraphrase to its score. A
            paraphrase that appears more than once yields a single entry.
        """
        # Moving the model is loop-invariant, so do it once up front.
        self.adequacy_model = self.adequacy_model.to(device)
        adequacy_scores = {}
        for para_phrase in para_phrases:
            adequacy_score = self._adequacy_score(input_phrase, para_phrase, device)
            if adequacy_score >= adequacy_threshold:
                adequacy_scores[para_phrase] = adequacy_score
        return adequacy_scores
|
|