from typing import Any, Dict, List

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

# `accelerate` and `bitsandbytes` are not used directly, but both must be
# installed: transformers relies on them for device_map="auto" and load_in_8bit=True.


class KeywordsStoppingCriteria(StoppingCriteria):
    """Stops generation once any keyword token has been produced `occurrences` times."""

    def __init__(self, keywords_ids: List[int], occurrences: int):
        super().__init__()
        self.keywords = keywords_ids
        self.occurrences = occurrences
        self.count = 0

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Check only the most recently generated token.
        if input_ids[0][-1] in self.keywords:
            self.count += 1
            if self.count == self.occurrences:
                return True
        return False


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the model (8-bit, sharded across available devices) and tokenizer from path.
        self.model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", load_in_8bit=True)
        self.tokenizer = AutoTokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[str]]:
        """
        Args:
            data (dict): The payload with the text prompt under the "input" key.
        """
        # Extract the prompt; fall back to the whole payload if no "input" key is present.
        prompt = data.pop("input", data)

        stop_words = ["."]
        # Index 1 skips the BOS token that tokenizers such as LLaMA's prepend;
        # adjust for tokenizers that do not add one.
        stop_ids = [self.tokenizer.encode(w)[1] for w in stop_words]

        gen_outputs = []
        gen_outputs_no_input = []
        # Move the encoded prompt to the model's device (tokenizer output starts on CPU).
        gen_input = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        # Sample five candidate continuations, each stopping after two stop words.
        for _ in range(5):
            stop_criteria = KeywordsStoppingCriteria(stop_ids, occurrences=2)
            gen_output = self.model.generate(
                gen_input.input_ids,
                do_sample=True,
                top_k=10,
                top_p=0.95,
                max_new_tokens=100,
                # penalty_alpha only takes effect in contrastive search
                # (do_sample=False); it is ignored when sampling.
                penalty_alpha=0.6,
                stopping_criteria=StoppingCriteriaList([stop_criteria]),
            )
            gen_outputs.append(gen_output)
            # Also keep the continuation with the prompt tokens stripped off.
            gen_outputs_no_input.append(gen_output[0][len(gen_input.input_ids[0]):])

        gen_outputs_decoded = [
            self.tokenizer.decode(output[0], skip_special_tokens=True) for output in gen_outputs
        ]
        gen_outputs_no_input_decoded = [
            self.tokenizer.decode(output, skip_special_tokens=True)
            for output in gen_outputs_no_input
        ]

        return {
            "gen_outputs_decoded": gen_outputs_decoded,
            "gen_outputs_no_input_decoded": gen_outputs_no_input_decoded,
        }
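
# A minimal local smoke test (a sketch, not part of the handler contract). It
# assumes a CUDA-capable machine with bitsandbytes installed; the path and
# prompt below are placeholders, since the Inference Endpoints runtime normally
# constructs the handler with the repository path itself.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")  # "." assumes model files sit in the working directory
    result = handler({"input": "The quick brown fox"})
    print(result["gen_outputs_no_input_decoded"])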