davidkim205's picture
Upload folder using huggingface_hub
577164e verified
raw
history blame contribute delete
No virus
2.7 kB
import torch
import argparse
import os
import random
import os
import json
import random, os
import numpy as np
import torch
from transformers import StoppingCriteria, StoppingCriteriaList
from transformers import TextStreamer, GenerationConfig
class LocalStoppingCriteria(StoppingCriteria):
    """Stopping criterion that halts generation on any of a set of stop words.

    Two matching strategies are implemented: ``_compare_decode`` (string
    suffix match on the decoded text, used by ``__call__``) and
    ``_compare_token`` (token-id suffix match). Both inspect only the first
    row of ``input_ids``, i.e. the first sequence of the batch.

    Args:
        tokenizer: Tokenizer used to encode the stop words and to decode the
            running sequence.
        stop_words: Iterable of stop strings. ``None`` (the default) means no
            stop words.
    """

    def __init__(self, tokenizer, stop_words=None):
        super().__init__()
        # Copy the input so the instance never shares a default or
        # caller-owned list; the token ids below are computed once from it.
        stop_words = [] if stop_words is None else list(stop_words)
        # reshape(-1) keeps single-token stop words 1-D. squeeze() collapsed
        # them to 0-d tensors, which _compare_token then skipped entirely.
        stops = [
            tokenizer(stop_word, return_tensors='pt', add_special_tokens=False)['input_ids'].reshape(-1)
            for stop_word in stop_words
        ]
        print('stop_words', stop_words)
        print('stop_words_ids', stops)
        self.stop_words = stop_words
        # Keep the ids on the GPU when there is one, but do not crash on
        # CPU-only hosts; _compare_token aligns devices at comparison time.
        self.stops = [stop.cuda() if torch.cuda.is_available() else stop for stop in stops]
        self.tokenizer = tokenizer

    def _compare_token(self, input_ids):
        """Return True if the first sequence ends with any stop-word token ids."""
        sequence = input_ids[0]
        for stop in self.stops:
            stop_len = stop.numel()
            # Skip stops that cannot be compared: wrong rank, empty, or longer
            # than the sequence generated so far (would be a shape mismatch).
            if stop.dim() != 1 or stop_len == 0 or stop_len > sequence.shape[-1]:
                continue
            tail = sequence[-stop_len:]
            if bool((tail == stop.to(tail.device)).all()):
                return True
        return False

    def _compare_decode(self, input_ids):
        """Return True if the decoded first sequence ends with any stop word."""
        # NOTE: the whole sequence is decoded on every call.
        input_str = self.tokenizer.decode(input_ids[0])
        # str.endswith accepts a tuple; an empty tuple yields False.
        return input_str.endswith(tuple(self.stop_words))

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Decide whether generation should stop.

        Args:
            input_ids: Token ids generated so far; only row 0 is examined.
            scores: Unused.
            **kwargs: Accepted for compatibility with the base-class call
                signature; ignored.

        Returns:
            True when the decoded text ends with one of the stop words.
        """
        return self._compare_decode(input_ids)
def seed_everything(seed: int) -> None:
    """Seed every random number generator used here and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and every CUDA device),
    exports ``PYTHONHASHSEED`` and configures the cuDNN backend flags.

    Args:
        seed: Value used to seed all generators.
    """
    random.seed(seed)
    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so this only
    # influences processes launched after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all CUDA devices, not just the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick different algorithms from
    # run to run, which defeats the determinism requested above.
    torch.backends.cudnn.benchmark = False
def generation(model, tokenizer, x, max_new_tokens=1024):
    """Sample a completion for prompt ``x`` and return the generated text.

    Tokens are streamed through a ``TextStreamer`` while they are produced.
    Generation uses sampling (temperature 1.0, top-p 0.8, top-k 100) and
    stops at the tokenizer's EOS token.

    Args:
        model: Model exposing ``generate``.
        tokenizer: Tokenizer matching ``model``; its ``eos_token`` is used as
            stop word, pad token and EOS token.
        x: Prompt string.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded output with the prompt (when it can be located in the
        decoded text) and every EOS token string removed.
    """
    stopping_criteria = StoppingCriteriaList(
        [LocalStoppingCriteria(tokenizer=tokenizer, stop_words=[tokenizer.eos_token])])
    streamer = TextStreamer(tokenizer)
    # early_stopping was dropped: it only controls beam-based generation and
    # has no effect on single-beam sampling.
    generation_config = GenerationConfig(
        temperature=1.0,
        top_p=0.8,
        top_k=100,
        max_new_tokens=max_new_tokens,
        do_sample=True,
    )
    # Put the inputs on the model's own device; fall back to the previous
    # hard-coded 'cuda' when the model does not expose a `device` attribute.
    device = getattr(model, 'device', 'cuda')
    inputs = tokenizer(
        x,
        return_tensors='pt',
        return_token_type_ids=False
    ).to(device)
    gened = model.generate(
        **inputs,
        generation_config=generation_config,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        stopping_criteria=stopping_criteria,
        streamer=streamer,
    )
    response = tokenizer.decode(gened[0])
    # Cut at the FIRST occurrence of the prompt only. Splitting on every
    # occurrence left the prompt in place whenever the model echoed it.
    _, prompt_found, completion = response.partition(x)
    if prompt_found:
        response = completion
    return response.replace(tokenizer.eos_token, '')