|
from typing import Dict, List, Any |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList |
|
import torch |
|
from accelerate import Accelerator |
|
import bitsandbytes as bnb |
|
|
|
# Instantiated for its side effects (device/distributed setup); it is not
# referenced anywhere in this chunk. NOTE(review): presumably relied on by the
# device_map="auto" load below — confirm it is actually needed, else dead code.
accelerator = Accelerator()
|
|
|
|
|
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once keyword tokens have appeared a given number of times.

    Each call inspects the most recently generated token of the first sequence
    in the batch; when it matches one of ``keywords_ids`` an internal counter
    is advanced, and generation stops once the counter reaches ``occurrences``.
    The counter is stateful, so a fresh instance must be used per
    ``generate()`` call (as `EndpointHandler` does).
    """

    def __init__(self, keywords_ids: list, occurrences: int):
        """
        Args:
            keywords_ids: Token ids treated as stop keywords.
            occurrences: Number of keyword hits after which generation stops.
        """
        super().__init__()
        self.keywords = keywords_ids
        self.occurrences = occurrences
        self.count = 0

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only the first sequence in the batch is inspected — sufficient for
        # the single-prompt generation performed by EndpointHandler.
        if input_ids[0][-1] in self.keywords:
            self.count += 1
        # Fix: compare with >= and return the condition unconditionally. The
        # original `== self.occurrences` check would never fire again if the
        # counter ever overshot the target; behavior on the normal path is
        # unchanged.
        return self.count >= self.occurrences
|
|
|
class EndpointHandler:
    """Inference-endpoint wrapper: loads a causal LM and samples completions.

    The model is loaded 8-bit quantized with ``device_map="auto"``, so
    accelerate/bitsandbytes place the weights across available devices.
    """

    def __init__(self, path=""):
        # load_in_8bit requires bitsandbytes (imported at module level).
        self.model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", load_in_8bit=True)
        self.tokenizer = AutoTokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[str]]:
        """Generate several sampled completions for the prompt in ``data``.

        Args:
            data: Request payload. Keys:
                - ``"input"``: the text prompt. If absent, ``data`` itself is
                  used as the prompt (original fallback behavior, preserved).
                - ``"num_samples"`` (optional): how many completions to draw;
                  defaults to 5.

        Returns:
            Dict with:
                - ``"gen_outputs_decoded"``: full decoded texts (prompt + completion).
                - ``"gen_outputs_no_input_decoded"``: the completions only.
        """
        # Renamed from `input` to avoid shadowing the builtin.
        prompt = data.pop("input", data)
        # Generalized: sample count was hard-coded to 5; callers may override.
        num_samples = int(data.pop("num_samples", 5))

        # Stop after the second '.' token is generated.
        # NOTE(review): encode(w)[1] assumes the tokenizer prepends exactly one
        # special token (e.g. BOS) so the keyword id sits at index 1 — confirm
        # for the target tokenizer.
        stop_words = ['.']
        stop_ids = [self.tokenizer.encode(w)[1] for w in stop_words]

        gen_input = self.tokenizer(prompt, return_tensors="pt")
        # Loop-invariant prompt length, hoisted out of the sampling loop.
        prompt_len = len(gen_input.input_ids[0])

        gen_outputs = []
        gen_outputs_no_input = []
        for _ in range(num_samples):
            # Fresh criteria per iteration: the occurrence counter is stateful.
            stop_criteria = KeywordsStoppingCriteria(stop_ids, occurrences=2)
            gen_output = self.model.generate(
                gen_input.input_ids,
                do_sample=True,
                top_k=10,
                top_p=0.95,
                max_new_tokens=100,
                penalty_alpha=0.6,
                stopping_criteria=StoppingCriteriaList([stop_criteria]),
            )
            gen_outputs.append(gen_output)
            # Keep only the newly generated tail (strip the prompt tokens).
            gen_outputs_no_input.append(gen_output[0][prompt_len:])

        gen_outputs_decoded = [
            self.tokenizer.decode(out[0], skip_special_tokens=True) for out in gen_outputs
        ]
        gen_outputs_no_input_decoded = [
            self.tokenizer.decode(out, skip_special_tokens=True) for out in gen_outputs_no_input
        ]

        return {"gen_outputs_decoded": gen_outputs_decoded, "gen_outputs_no_input_decoded": gen_outputs_no_input_decoded}