# --- Earlier version of this handler (kept commented out for reference) -----
# import torch
# from typing import Dict, List, Any
# from transformers import AutoTokenizer, AutoModelForCausalLM
# class EndpointHandler:
# def __init__(self, path: str = ""):
# self.tokenizer = AutoTokenizer.from_pretrained(path, padding_side = "left")
# self.model = AutoModelForCausalLM.from_pretrained(path, device_map = "auto", torch_dtype=torch.float16)
# def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
# """
# Args:
# data (:obj:):
# includes the input data and the parameters for the inference.
# Return:
# A :obj:`list`:. The list contains the answer and scores of the inference inputs
# """
# # process input
# inputs_dict = data.pop("inputs", data)
# parameters = data.pop("parameters", {})
# prompts = [f"<human>: {prompt}\n<bot>:" for prompt in inputs_dict]
# self.tokenizer.pad_token = self.tokenizer.eos_token
# inputs = self.tokenizer(prompts, truncation=True, max_length=2048-512,
# return_tensors='pt', padding=True).to(self.model.device)
# input_length = inputs.input_ids.shape[1]
# if parameters.get("deterministic", False):
# torch.manual_seed(42)
# outputs = self.model.generate(
# **inputs, max_new_tokens=512, do_sample=True, temperature=0.7, top_p=0.7, top_k=50
# )
# output_strs = self.tokenizer.batch_decode(outputs[:, input_length:], skip_special_tokens=True)
# return {"generated_text": output_strs}
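
# --- Active implementation ---------------------------------------------------
# Custom handler for Hugging Face Inference Endpoints: batches prompts with left
# padding, forwards user-supplied generation parameters to model.generate, and
# stops decoding as soon as the model begins a new "\n<human>:" turn.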
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
from typing import Dict, List, Any

class StopWordsCriteria(StoppingCriteria):
    """Stops generation once any of the given stop strings appears in the decoded text."""

    def __init__(self, stop_words, tokenizer):
        self.tokenizer = tokenizer
        self.stop_words = stop_words
        self._cache_str = ''

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Append only the newest token's decoded text to the running cache.
        # Note: only the first sequence in the batch is inspected.
        self._cache_str += self.tokenizer.decode(input_ids[0, -1])
        for stop_word in self.stop_words:
            if stop_word in self._cache_str:
                return True
        return False

class EndpointHandler:
    def __init__(self, path: str = ""):
        self.tokenizer = AutoTokenizer.from_pretrained(path, padding_side="left")
        # Reuse EOS as the pad token so left-padded batching works.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", torch_dtype=torch.float16)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Args:
            data (:obj:`dict`):
                Contains the input prompts under "inputs" and optional generation
                parameters under "parameters".
        Return:
            A :obj:`dict` with a single "generated_text" key holding the list of
            generated completions, one per input prompt.
        """
        # Process input: wrap each prompt in the model's <human>/<bot> chat template.
        inputs_list = data.pop("inputs", data)
        parameters = data.pop("parameters", {})
        prompts = [f"<human>: {prompt}\n<bot>:" for prompt in inputs_list]
        # Optional pre-truncation: keep only the first N whitespace-separated words of each prompt.
        preset_truncation_token_value = parameters.pop("preset_truncation_token", None)
        if preset_truncation_token_value:
            DELIMITER = " "
            prompts = [DELIMITER.join(prompt.split(DELIMITER)[:preset_truncation_token_value]) for prompt in prompts]
        with torch.no_grad():
            # Tokenize with left padding; reserve room for up to 512 new tokens within the 2048-token context.
            inputs = self.tokenizer(prompts, truncation=True, max_length=2048 - 512,
                                    return_tensors='pt', padding=True).to(self.model.device)
            input_length = inputs.input_ids.shape[1]
            # Optional deterministic seeding; popped so it is not forwarded to generate().
            deterministic_seed = parameters.pop("deterministic_seed", None)
            if deterministic_seed is not None:
                torch.manual_seed(deterministic_seed)
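            # Generate with the remaining user-supplied parameters, stopping early
            # once the model starts a new "<human>:" turn.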
outputs = self.model.generate(
**inputs, **parameters,
stopping_criteria=StoppingCriteriaList(
[StopWordsCriteria(['\n<human>:'], self.tokenizer)]
)
)
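            # Decode only the newly generated tokens and strip any trailing stop sequence from the text.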
output_strs = self.tokenizer.batch_decode(outputs[:, input_length:], skip_special_tokens=True)
output_strs = [output_str.replace("\n<human>:", "") for output_str in output_strs]
return {"generated_text": output_strs}
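

# Minimal local smoke test (a sketch, not part of the Inference Endpoints runtime contract).
# `MODEL_PATH` below is a placeholder: point it at a checkpoint that uses the
# "<human>:" / "<bot>:" chat format this handler expects.
if __name__ == "__main__":
    MODEL_PATH = "path/or/hub-id/of-the-chat-model"  # placeholder, replace with a real checkpoint
    handler = EndpointHandler(MODEL_PATH)
    payload = {
        "inputs": ["What is the capital of France?"],
        "parameters": {
            "max_new_tokens": 64,
            "do_sample": True,
            "temperature": 0.7,
            "top_p": 0.7,
            "deterministic_seed": 42,
        },
    }
    print(handler(payload))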