# HF Spaces: Running on Zero
import logging
from typing import List

import torch
from transformers import (
    LogitsProcessor,
)
class StopAfterTokenIsGenerated(LogitsProcessor):
    """Logits processor that forces EOS once any stop sequence is generated.

    When the tail of a sequence's generated ids matches one of the configured
    stop sequences, every vocabulary logit for that sequence is set to
    ``-inf`` except the EOS token, so the next sampled token is guaranteed to
    be EOS and generation stops.
    """

    def __init__(self, stops: List[torch.Tensor], eos_token_id: int):
        """
        Args:
            stops: Token-id tensors; each one is a stop sequence to match
                against the tail of the generated ids.
            eos_token_id: Id of the EOS token to force once a stop matches.
        """
        super().__init__()
        self.stops = stops
        self.eos_token_id = eos_token_id
        # Lazy %-args: the tensor reprs are only built if INFO is enabled.
        logging.info("Stopping criteria words ids: %s", self.stops)
        # The first call sees the prompt itself; skip matching there so a
        # stop sequence already present in the prompt does not immediately
        # end generation. Re-armed via reset().
        self.first_batch = True

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        """
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
                search or log softmax for each vocabulary token when using beam search
        Return:
            `torch.FloatTensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores.
        """
        if self.first_batch:
            # Prompt-only call: pass the scores through untouched.
            self.first_batch = False
            return scores

        for seq_no, seq in enumerate(input_ids):
            for stop in self.stops:
                # Compare on the same device/dtype as the generated ids.
                stop = stop.to(device=seq.device, dtype=seq.dtype)
                if (
                    len(seq) >= len(stop)
                    and torch.all((stop == seq[-len(stop) :])).item()
                ):
                    # Force EOS: every other token gets logit -inf, EOS gets 0,
                    # modifying `scores` in place for this sequence.
                    scores[seq_no, :] = -float("inf")
                    scores[seq_no, self.eos_token_id] = 0
                    logging.info("Stopping criteria found: %s", stop)
                    break  # no need to test remaining stops for this sequence
        return scores

    def reset(self):
        """Re-arm the first-batch skip before starting a new generation."""
        self.first_batch = True