Research-chatbot / stopping.py
pseudotensor's picture
Update with h2oGPT hash 61628d335bdb685fdcc63ca9821cf5607f41a9e3
e0ba5f2
raw
history blame
No virus
1.08 kB
import traceback
from queue import Queue
from threading import Thread
import collections.abc
import torch
from transformers import StoppingCriteria
class StoppingCriteriaSub(StoppingCriteria):
    """Stopping criterion driven by stop token sequences and encounter counts.

    Each call compares the tail of the first row of ``input_ids`` against
    every configured stop sequence.  Every match increments a per-stop
    counter, and the criterion reports ``True`` once a stop's counter reaches
    its required number of encounters.

    Counters live on the instance and are never reset here, so a reused
    instance carries its counts over between calls.
    """

    def __init__(self, stops=None, encounters=None, device="cuda"):
        """Store the stop sequences and how often each must be seen.

        Args:
            stops: Iterable of tensors, one per stop sequence.  Each is moved
                to ``device``.  ``None`` means no stop sequences.
            encounters: Sequence of required match counts.  It is cycled over
                the stops (stop ``i`` uses ``encounters[i % len(encounters)]``),
                so ``len(stops)`` must be a multiple of ``len(encounters)``.
                ``None`` means an empty sequence, which is only valid when
                there are no stops.
            device: Device the stop tensors are moved to.

        Raises:
            ValueError: If stops are given but their count is not a multiple
                of the number of encounters (including no encounters at all).
        """
        super().__init__()
        # ``None`` defaults avoid sharing one mutable list between instances.
        stops = [] if stops is None else list(stops)
        encounters = [] if encounters is None else list(encounters)
        # Explicit raise instead of ``assert`` so the check survives ``-O``;
        # the emptiness test also avoids a modulo-by-zero on empty encounters.
        if stops and (not encounters or len(stops) % len(encounters) != 0):
            raise ValueError("Number of stops and encounters must match")
        self.encounters = encounters
        self.stops = [stop.to(device) for stop in stops]
        self.num_stops = [0] * len(stops)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return ``True`` once any stop sequence has matched often enough.

        Only ``input_ids[0]`` is inspected; ``scores`` and ``kwargs`` are
        unused.  The stop tensors are assumed to be 1-D and on the same device
        as ``input_ids`` (TODO confirm against callers).
        """
        sequence = input_ids[0]
        sequence_len = len(sequence)
        for stopi, stop in enumerate(self.stops):
            stop_len = len(stop)
            # A stop longer than the sequence cannot match, and comparing
            # mismatched lengths would either raise or broadcast into a false
            # positive.  An empty stop would slice the whole sequence
            # (``[-0:]``), so it is skipped as well.
            if stop_len == 0 or stop_len > sequence_len:
                continue
            if torch.all(stop == sequence[-stop_len:]).item():
                self.num_stops[stopi] += 1
                if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
                    return True
        return False