# handler.py -- custom Hugging Face Inference Endpoints handler
# (provenance: nizar-sayad's repository, commit 9b01ad2, "Update handler.py", ~2.5 kB)
from typing import Dict, List, Any
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList,StoppingCriteria
import torch
class EndpointHandler:
    """Custom Inference Endpoints handler for a causal language model.

    Loads a model/tokenizer pair from a local checkpoint directory and, per
    request, samples five continuations of the prompt, stopping each one
    after a fixed number of sentence-ending ('.') tokens.
    """

    def __init__(self, path="."):
        # Load model and tokenizer from the checkpoint directory that the
        # endpoint mounts at `path`.
        self.model = AutoModelForCausalLM.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate five sampled continuations for ``data["inputs"]``.

        Returns a dict with:
          * ``"gen_outputs_decoded"``: the five full decoded texts
            (prompt included);
          * ``"gen_outputs_no_input_decoded"``: the five decoded
            continuations with the prompt tokens stripped.
        """
        prompt = data["inputs"]
        stop_words = ['.']
        # NOTE(review): index [1] assumes the tokenizer prepends exactly one
        # special token (e.g. BOS) before the word's own token -- confirm for
        # this tokenizer; one that does not would make this pick the wrong id
        # (or raise IndexError). Behavior preserved pending confirmation.
        stop_ids = [self.tokenizer.encode(w)[1] for w in stop_words]

        gen_outputs = []
        gen_outputs_no_input = []
        # BUG FIX: the original called self.tokenizer(input, ...) -- tokenizing
        # the *builtin function* `input` instead of the request payload, which
        # fails at runtime on every request.
        gen_input = self.tokenizer(prompt, return_tensors="pt")
        for _ in range(5):
            # A fresh criteria object per sample: its occurrence counter is
            # stateful and must not carry over between generate() calls.
            stop_criteria = KeywordsStoppingCriteria(stop_ids, occurrences=2)
            gen_output = self.model.generate(
                gen_input.input_ids,
                do_sample=True,
                top_k=10,
                top_p=0.95,
                max_new_tokens=100,
                penalty_alpha=0.6,
                stopping_criteria=StoppingCriteriaList([stop_criteria]),
            )
            gen_outputs.append(gen_output)
            # Keep only the newly generated tail (drop the prompt tokens).
            gen_outputs_no_input.append(gen_output[0][len(gen_input.input_ids[0]):])

        gen_outputs_decoded = [
            self.tokenizer.decode(out[0], skip_special_tokens=True)
            for out in gen_outputs
        ]
        gen_outputs_no_input_decoded = [
            self.tokenizer.decode(tail, skip_special_tokens=True)
            for tail in gen_outputs_no_input
        ]
        return {
            "gen_outputs_decoded": gen_outputs_decoded,
            "gen_outputs_no_input_decoded": gen_outputs_no_input_decoded,
        }


class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once tokens from `keywords_ids` have appeared
    `occurrences` times in total.

    The counter is cumulative across __call__ invocations, so a fresh
    instance must be used for each generate() call.
    """

    def __init__(self, keywords_ids: list, occurrences: int):
        super().__init__()
        self.keywords = keywords_ids
        self.occurrences = occurrences
        self.count = 0  # how many keyword tokens have been seen so far

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only the most recently generated token can be a new occurrence.
        if input_ids[0][-1] in self.keywords:
            self.count += 1
            if self.count == self.occurrences:
                return True
        return False