yfdeng's picture
init
744eb4e
import torch
from transformers import StoppingCriteria
class KeywordsStoppingCriteria(StoppingCriteria):
def __init__(self, keywords, tokenizer, input_ids):
self.keywords = keywords
self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords]
self.keyword_ids = [keyword_id[0] for keyword_id in self.keyword_ids if type(keyword_id) is list and len(keyword_id) == 1]
self.tokenizer = tokenizer
self.start_len = None
self.input_ids = input_ids
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
if self.start_len is None:
self.start_len = self.input_ids.shape[1]
else:
for keyword_id in self.keyword_ids:
if output_ids[0, -1] == keyword_id:
return True
outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
for keyword in self.keywords:
if keyword in outputs:
return True
return False