nizar-sayad's picture
custom handler
1d29082
raw
history blame
No virus
2.62 kB
from typing import Dict, List, Any
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
import torch
from accelerate import Accelerator
import bitsandbytes as bnb
# Instantiate Accelerate's Accelerator once, at module level.
# NOTE(review): `accelerator` is not referenced anywhere else in the visible
# code; confirm whether this instantiation is needed before removing it.
accelerator = Accelerator()
# Create a stopping criteria class
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation after keyword tokens have been produced enough times.

    Only the first sequence of the batch (``input_ids[0]``) is inspected, and
    the internal counter is never reset, so a fresh instance should be used
    for every ``generate`` call.

    Args:
        keywords_ids: Token ids that count as a keyword occurrence.
        occurrences: Number of keyword tokens to observe before stopping.
    """

    def __init__(self, keywords_ids: list, occurrences: int):
        super().__init__()
        self.keywords = keywords_ids
        self.occurrences = occurrences
        self.count = 0

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Check the most recent token and report whether generation should stop.

        Args:
            input_ids: Token ids generated so far; only row 0 is examined.
            scores: Unused; accepted for interface compatibility.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            True when the last token is a keyword and the number of keyword
            tokens seen so far has reached ``occurrences``; False otherwise.
        """
        # Convert to a plain int so the membership test compares numbers
        # rather than relying on tensor-vs-int equality semantics.
        last_token = int(input_ids[0][-1])
        if last_token in self.keywords:
            self.count += 1
            # ">=" rather than "==": an exact match would never fire again
            # once the counter has passed the threshold, nor for a
            # threshold below 1.
            return self.count >= self.occurrences
        return False
class EndpointHandler:
    """Custom inference handler that samples several completions for a prompt.

    Each request produces ``NUM_SAMPLES`` independently sampled continuations.
    Every continuation ends once ``STOP_OCCURRENCES`` stop-word tokens have
    been generated or 100 new tokens have been produced, whichever is first.
    """

    # Number of independent samples generated per request.
    NUM_SAMPLES = 5
    # Words whose token ids end a sample once seen STOP_OCCURRENCES times.
    STOP_WORDS = ('.',)
    # How many stop-word tokens must be generated before a sample stops.
    STOP_OCCURRENCES = 2

    def __init__(self, path=""):
        """Load the causal language model and its tokenizer from ``path``.

        Args:
            path: Directory or identifier passed to ``from_pretrained``.
        """
        self.model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", load_in_8bit=True)
        self.tokenizer = AutoTokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[str]]:
        """Generate sampled continuations for the prompt in ``data``.

        Args:
            data: The request payload. The prompt is taken (and removed) from
                the ``"input"`` key; when that key is missing the payload
                itself is handed to the tokenizer.

        Returns:
            A dict with two lists of ``NUM_SAMPLES`` strings each:
            ``"gen_outputs_decoded"`` holds prompt plus continuation, and
            ``"gen_outputs_no_input_decoded"`` holds the continuation only.
        """
        prompt = data.pop("input", data)

        # NOTE(review): index 1 assumes the tokenizer prepends exactly one
        # special token (e.g. BOS) so that the word's own token sits at
        # position 1 -- confirm for the tokenizer shipped with this model.
        stop_ids = [self.tokenizer.encode(word)[1] for word in self.STOP_WORDS]

        encoded = self.tokenizer(prompt, return_tensors="pt")
        # The model is placed by device_map="auto", so the prompt tensors must
        # be moved to the model's device before calling generate().
        device = self.model.device
        input_ids = encoded.input_ids.to(device)
        extra_inputs = {}
        attention_mask = encoded.get("attention_mask")
        if attention_mask is not None:
            # Pass the mask explicitly instead of letting generate() infer it.
            extra_inputs["attention_mask"] = attention_mask.to(device)
        prompt_length = input_ids.shape[1]

        full_texts = []
        continuations = []
        for _ in range(self.NUM_SAMPLES):
            # The criterion keeps a running count, so each sample needs its own.
            stop_criteria = KeywordsStoppingCriteria(stop_ids, occurrences=self.STOP_OCCURRENCES)
            # NOTE(review): penalty_alpha is documented as a contrastive-search
            # setting; confirm it has any effect together with do_sample=True.
            sequence = self.model.generate(
                input_ids,
                do_sample=True,
                top_k=10,
                top_p=0.95,
                max_new_tokens=100,
                penalty_alpha=0.6,
                stopping_criteria=StoppingCriteriaList([stop_criteria]),
                **extra_inputs,
            )[0]
            # Decode right away so generated tensors are not kept around.
            full_texts.append(self.tokenizer.decode(sequence, skip_special_tokens=True))
            continuations.append(
                self.tokenizer.decode(sequence[prompt_length:], skip_special_tokens=True)
            )

        return {
            "gen_outputs_decoded": full_texts,
            "gen_outputs_no_input_decoded": continuations,
        }