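"""TorchServe custom handler that serves a Hugging Face causal language model
for text generation and returns the decoded samples as a JSON string."""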
import gc
import json
import logging

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class TextGenerationHandlerForString(BaseHandler):
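    """Handler that turns a text prompt into generated text samples."""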

    def __init__(self):
        super().__init__()
        self.model = None
        self.tokenizer = None
        self.device = None
        self.task_config = None
        self.initialized = False

    def load_model(self, model_dir):
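        """Load the model and tokenizer from ``model_dir`` onto ``self.device``.

        On GPU, fp32 checkpoints are cast to fp16 to halve memory use.
        """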
        if self.device.type == "cuda":
            # low_cpu_mem_usage avoids materializing an extra full copy of
            # the weights in host RAM while the checkpoint is loaded.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_dir, torch_dtype="auto", low_cpu_mem_usage=True
            )
            if self.model.dtype == torch.float32:
                self.model = self.model.half()
        else:
            self.model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype="auto")
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        # Seed the generation kwargs from the checkpoint's task-specific
        # defaults; task_specific_params may be missing, None, or lack the
        # "text-generation" entry.
        try:
            self.task_config = self.model.config.task_specific_params["text-generation"] or {}
        except (KeyError, TypeError):
            self.task_config = {}

        self.model.to(self.device, non_blocking=True)

    def initialize(self, ctx):
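        """Called once per worker by TorchServe: resolve the device, load the model."""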
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        # Bind the worker to the GPU TorchServe assigned it, if any; fall
        # back to CPU when no gpu_id was provided.
        self.device = torch.device(
            "cuda:" + str(properties.get("gpu_id"))
            if torch.cuda.is_available() and properties.get("gpu_id") is not None
            else "cpu"
        )
        self.load_model(model_dir)
        self.model.eval()
        self.initialized = True

    def preprocess(self, requests):
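        """Pull "text", "num_samples", and "length" out of the request body.

        Only the last request in a batch is kept, so this handler
        effectively serves batch size 1.
        """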
        input_batch = {}
        # Successive requests overwrite the same keys, so only the last
        # request in a batch is actually served.
        for data in requests:
            body = data.get("body")
            input_batch["input_text"] = body.get("text")
            input_batch["num_samples"] = body.get("num_samples")
            input_batch["length"] = body.get("length")
        del requests
        gc.collect()
        return input_batch

    def inference(self, input_batch):
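        """Generate token ids for the request, applying per-request overrides."""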
        input_text = input_batch["input_text"]
        length = input_batch["length"]
        num_samples = input_batch["num_samples"]
        input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to(
            self.device
        )
        # Per-request overrides of the checkpoint's generation defaults.
        self.task_config["max_length"] = length
        self.task_config["num_return_sequences"] = num_samples
        inference_output = self.model.generate(input_ids, **self.task_config)
        # Release cached, unused GPU memory back to the driver between requests.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        del input_batch
        gc.collect()
        return inference_output

    def postprocess(self, inference_output):
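        """Decode generated token ids and JSON-encode the resulting strings."""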
        output = self.tokenizer.batch_decode(
            inference_output.tolist(), skip_special_tokens=True
        )
        del inference_output
        gc.collect()
        # TorchServe expects a list with one response entry per request.
        return [json.dumps(output, ensure_ascii=False)]

    def handle(self, data, context):
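        """TorchServe entry point: preprocess -> inference -> postprocess."""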
        self.context = context
        data = self.preprocess(data)
        data = self.inference(data)
        data = self.postprocess(data)
        return data
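

# Illustrative request (hypothetical model name and default TorchServe
# inference port; the field names match what preprocess() reads from the
# JSON body):
#
#   curl -X POST http://localhost:8080/predictions/text_generation \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Once upon a time", "num_samples": 2, "length": 50}'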