h2ogpt-chatbot / stopping.py
pseudotensor's picture
Update with h2oGPT hash 321b64cb3f52f33f91b6052f5edea590b829aaa1
65121b5
raw
history blame
1.07 kB
import traceback
from queue import Queue
from threading import Thread
import collections.abc
import torch
from transformers import StoppingCriteria
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once a stop sequence has ended the output often enough.

    Each entry of ``stops`` is a tensor of token ids. On every call the tail
    of the first row of ``input_ids`` is compared with each stop sequence;
    every match increments that stop's counter, and generation is halted as
    soon as a counter reaches its required number of encounters.

    The counters live on the instance and are only initialised in
    ``__init__``, so they keep accumulating for as long as the same instance
    is reused.

    Args:
        stops: Iterable of token-id tensors (presumably 1-D; confirm against
            the caller). ``None`` or empty means "never stop".
        encounters: Required match counts. Stop ``i`` uses
            ``encounters[i % len(encounters)]``, so a shorter list is cycled
            over the stops.
        device: Device the stop tensors are moved to up front. ``None``
            selects ``"cuda"`` when available and ``"cpu"`` otherwise.

    Raises:
        ValueError: If stops are given but ``len(stops)`` is not a multiple
            of ``len(encounters)`` (including an empty ``encounters``).
    """

    def __init__(self, stops=None, encounters=None, device=None):
        super().__init__()
        # None instead of mutable [] defaults; copy so later mutation of the
        # caller's containers cannot desynchronise stops and counters.
        stops = [] if stops is None else list(stops)
        encounters = [] if encounters is None else list(encounters)

        # Explicit raise rather than assert: asserts vanish under ``python -O``.
        # The emptiness check also prevents a modulo-by-zero on ``encounters``.
        if stops and (not encounters or len(stops) % len(encounters) != 0):
            raise ValueError("Number of stops and encounters must match")

        if device is None:
            # Keep the historical CUDA placement, but do not crash on
            # CPU-only hosts.
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.encounters = encounters
        self.stops = [stop.to(device) for stop in stops]
        self.num_stops = [0] * len(stops)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when any stop sequence has reached its encounter quota.

        Only ``input_ids[0]`` is inspected (assumes a ``(batch, seq_len)``
        layout -- TODO confirm); ``scores`` is not used.
        """
        sequence = input_ids[0]
        for stopi, stop in enumerate(self.stops):
            stop_len = len(stop)
            # An empty stop would make ``-0:`` select the whole sequence, and
            # a stop longer than the sequence would give mismatched (or
            # silently broadcast) shapes, so neither can be a real match.
            if stop_len == 0 or stop_len > len(sequence):
                continue
            if stop.device != sequence.device:
                # Follow the model's device and cache the moved tensor so the
                # transfer happens at most once per device change.
                stop = self.stops[stopi] = stop.to(sequence.device)
            if torch.all(stop == sequence[-stop_len:]).item():
                self.num_stops[stopi] += 1
                if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
                    return True
        return False