test_endpoint2 / handler.py
jordiclive's picture
Update handler.py
2f6df11 verified
raw
history blame
2.26 kB
from typing import Any, Dict, List
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Precision used when EndpointHandler loads the model weights.
dtype: torch.dtype = torch.bfloat16
class EndpointHandler:
    """Inference handler that scores a text prompt with a causal LM.

    Loads a tokenizer and causal language model from ``path``. When called,
    it returns the model logits for the prompt together with the per-token
    log-likelihood of the prompt's tokens.
    """

    def __init__(self, path=""):
        """Load the tokenizer, model and helper objects.

        Args:
            path: Model identifier or local directory passed to
                ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # device_map="auto" lets transformers choose device placement; the
        # weight precision comes from the module-level ``dtype``.
        self.model = AutoModelForCausalLM.from_pretrained(
            path, device_map="auto", torch_dtype=dtype
        )
        if self.tokenizer.pad_token is None:
            # Some tokenizers define no pad token; fall back to EOS.
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Text-generation pipeline over the already-loaded model.
        # NOTE(review): not used by __call__; kept for any external users.
        self.pipeline = pipeline(
            "text-generation", model=self.model, tokenizer=self.tokenizer
        )
        # Unreduced (per-position) cross-entropy. __call__ never pads its
        # input, so no ignore_index is set: ignoring pad_token_id would
        # silently zero out genuine EOS tokens whenever the pad token is
        # aliased to EOS above (and would fail if pad_token_id were None).
        self.ce = torch.nn.CrossEntropyLoss(reduction="none")

    def compute_log_likelihood(self, lm_logits, input_ids):
        """Return next-token log-probabilities for the first sequence.

        Args:
            lm_logits: Logits, assumed ``(batch, seq_len, vocab)``.
            input_ids: Token ids, assumed ``(batch, seq_len)``.

        Returns:
            A float32 tensor of shape ``(seq_len - 1,)`` whose element ``i``
            is the log-probability the model assigns to
            ``input_ids[0, i + 1]`` given the preceding tokens.
        """
        # Logits at position i predict the token at position i + 1.
        # Upcast to float32: cross-entropy in bfloat16 loses precision.
        predictions = lm_logits[..., :-1, :].float()
        target_ids = input_ids[..., 1:]
        ce_loss = self.ce(
            predictions.reshape(-1, predictions.size(-1)),
            target_ids.reshape(-1),
        )
        # Cross-entropy is the negative log-likelihood; keep batch item 0.
        return -ce_loss.view_as(target_ids)[0]

    def __call__(self, data: Any):
        """Score ``data["inputs"]`` with the model.

        Args:
            data: Request payload. ``data["inputs"]`` holds the text; if the
                key is absent, ``data`` itself is used as the input. A
                ``"parameters"`` entry is accepted but currently ignored.

        Returns:
            ``(logits, log_likelihood)``: the model logits for the single
            encoded sequence and the per-token log-likelihood computed by
            ``compute_log_likelihood``.
        """
        # Read instead of pop so the caller's payload is not mutated.
        inputs = data.get("inputs", data)
        # One unpadded sequence, placed on the model's device (works on CPU
        # as well as GPU, unlike a hard-coded CUDA device).
        input_tokens = self.tokenizer(
            [inputs], return_tensors="pt", padding=False
        ).to(self.model.device)
        # Pure inference: don't build or retain an autograd graph.
        with torch.no_grad():
            # No labels are passed, so the first model output is the logits.
            logits = self.model(
                input_ids=input_tokens["input_ids"],
                attention_mask=input_tokens["attention_mask"],
            )[0]
            log_likelihood = self.compute_log_likelihood(
                logits, input_tokens["input_ids"]
            )
        # NOTE(review): these are torch tensors; if the serving layer needs
        # JSON-serializable output, convert them (e.g. ``.tolist()``) —
        # confirm against the deployment.
        return (logits, log_likelihood)
# if __name__ == "__main__":
# model = EndpointHandler("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
# data = {
# "inputs": "Can you please let us know more details about your ",
# "parameters": {
# "no_generation": True,
# # "function_to_apply": "none",
# # "return_text": False,
# },
# }
# x = model(data)