|
import torch |
|
from transformers import StoppingCriteria |
|
|
|
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation as soon as one of the given keywords has been produced.

    Two checks are performed on every call:

    1. Token-id check: keywords for which ``tokenizer(keyword).input_ids`` is
       a list of exactly one id are matched against the most recently
       generated token.
    2. Text check: the tokens after the prompt are decoded (special tokens
       skipped) and searched for each keyword as a substring.

    Only the first sequence of the batch (row 0) is inspected.

    Args:
        keywords: Iterable of stop strings. It is iterated again on every
            call, so it should be a re-iterable collection such as a list.
        tokenizer: Callable returning an object with ``input_ids`` when given
            a string, and providing ``batch_decode``.
        input_ids: Prompt token ids; ``input_ids.shape[1]`` is taken as the
            prompt length, i.e. the offset at which generated text starts.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        super().__init__()
        self.keywords = keywords
        tokenized_keywords = [tokenizer(keyword).input_ids for keyword in keywords]
        # Keep only keywords that tokenize to exactly one id; those can be
        # detected with a cheap id comparison instead of decoding text.
        # NOTE(review): a tokenizer that adds special tokens (e.g. BOS) will
        # presumably yield more than one id, leaving this list empty -- verify
        # against the tokenizer in use.
        self.keyword_ids = [
            ids[0]
            for ids in tokenized_keywords
            if isinstance(ids, list) and len(ids) == 1
        ]
        self.tokenizer = tokenizer
        # Resolved lazily from ``input_ids`` on the first ``__call__``.
        self.start_len = None
        self.input_ids = input_ids

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when a stop keyword is present in the generated output.

        Args:
            output_ids: Token ids produced so far, indexed as
                ``output_ids[batch, position]``.
            scores: Unused; accepted to match the stopping-criteria call
                signature.
            **kwargs: Ignored.

        Returns:
            True if the last token of row 0 equals a single-token keyword id,
            or if any keyword occurs in the decoded text following the prompt
            in row 0; False otherwise.
        """
        if self.start_len is None:
            self.start_len = self.input_ids.shape[1]

        # The checks below run on every call, including the first one.
        # Skipping the first call would let a stop token emitted as the very
        # first generated token go unnoticed: the id check only looks at the
        # newest token, and the text check drops special tokens.
        if self.keyword_ids:
            last_token = output_ids[0, -1]
            if any(last_token == keyword_id for keyword_id in self.keyword_ids):
                return True

        generated_text = self.tokenizer.batch_decode(
            output_ids[:, self.start_len:], skip_special_tokens=True
        )[0]
        return any(keyword in generated_text for keyword in self.keywords)
|
|