File size: 2,502 Bytes
1d29082 dd60530 1d29082 dd60530 1d29082 ec9ec3e 9b01ad2 1d29082 d5680c2 1d29082 34bb7b3 dd60530 34bb7b3 228ed32 1d29082 34bb7b3 1d29082 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
from typing import Dict, List, Any
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList,StoppingCriteria
import torch
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once keyword tokens have been produced often enough.

    The criterion is stateful (it counts keyword hits), so a fresh instance
    must be created for every ``generate`` call.

    Args:
        keywords_ids: Token ids that count as a keyword occurrence.
        occurrences: Number of keyword tokens after which generation stops.
    """

    def __init__(self, keywords_ids: list, occurrences: int):
        super().__init__()
        self.keywords = keywords_ids
        self.occurrences = occurrences
        self.count = 0

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the keyword count has reached ``occurrences``.

        Only the most recent token of the first sequence (``input_ids[0]``)
        is inspected on each call.
        """
        # Convert the 0-d tensor to a plain int so the membership test is an
        # ordinary integer comparison.
        last_token_id = int(input_ids[0][-1])
        if last_token_id in self.keywords:
            self.count += 1
        # ">=" rather than "==" so the criterion can never be skipped past.
        return self.count >= self.occurrences


class EndpointHandler():
    """Inference handler that samples several continuations of a prompt.

    Each continuation is cut off after the second stop-word token ('.') or
    after 100 new tokens, whichever comes first.
    """

    def __init__(self, path="."):
        """Load the causal language model and its tokenizer from ``path``."""
        self.model = AutoModelForCausalLM.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate five sampled continuations for ``data["inputs"]``.

        Args:
            data: Request payload; the prompt is read from the ``"inputs"``
                key (presumably a string -- confirm against the caller).

        Returns:
            A dict with two lists of five decoded strings each:
            ``"gen_outputs_decoded"`` (prompt plus continuation) and
            ``"gen_outputs_no_input_decoded"`` (continuation only).

        Raises:
            KeyError: If ``data`` has no ``"inputs"`` key.
        """
        inputs = data["inputs"]

        stop_words = ['.']
        # Encode without special tokens and take the last piece, so the lookup
        # does not depend on whether the tokenizer prepends a BOS token.
        stop_ids = [
            self.tokenizer.encode(word, add_special_tokens=False)[-1]
            for word in stop_words
        ]

        # Tokenize the request prompt (not the ``input`` builtin).
        gen_input = self.tokenizer(inputs, return_tensors="pt")
        prompt_length = gen_input.input_ids.shape[1]

        gen_outputs_decoded = []
        gen_outputs_no_input_decoded = []
        for _ in range(5):
            # The criterion keeps a running count, so build a new one per run.
            stop_criteria = KeywordsStoppingCriteria(stop_ids, occurrences=2)
            # NOTE(review): penalty_alpha is normally a contrastive-search
            # setting; confirm it has the intended effect with do_sample=True.
            gen_output = self.model.generate(
                gen_input.input_ids,
                do_sample=True,
                top_k=10,
                top_p=0.95,
                max_new_tokens=100,
                penalty_alpha=0.6,
                stopping_criteria=StoppingCriteriaList([stop_criteria]),
            )
            sequence = gen_output[0]
            gen_outputs_decoded.append(
                self.tokenizer.decode(sequence, skip_special_tokens=True)
            )
            # Drop the prompt tokens to keep only the newly generated text.
            gen_outputs_no_input_decoded.append(
                self.tokenizer.decode(sequence[prompt_length:], skip_special_tokens=True)
            )

        return {
            "gen_outputs_decoded": gen_outputs_decoded,
            "gen_outputs_no_input_decoded": gen_outputs_no_input_decoded,
        }