import gc
import json
import logging

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class TextGenerationHandlerForString(BaseHandler):
    def __init__(self):
        super(TextGenerationHandlerForString, self).__init__()
        self.model = None
        self.tokenizer = None
        self.device = None
        self.task_config = None
        self.initialized = False

    def load_model(self, model_dir):
        if self.device.type == "cuda":
            self.model = AutoModelForCausalLM.from_pretrained(
                model_dir, torch_dtype="auto", low_cpu_mem_usage=True
            )
            # Cast to half precision on GPU if the checkpoint loaded as fp32.
            if self.model.dtype == torch.float32:
                self.model = self.model.half()
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_dir, torch_dtype="auto"
            )

        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)

        # Reuse the generation defaults shipped with the model config, if any.
        try:
            self.task_config = self.model.config.task_specific_params["text-generation"]
        except Exception:
            self.task_config = {}

        # TODO: Need to compare performance
        self.model.to(self.device, non_blocking=True)

    def initialize(self, ctx):
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        self.device = torch.device(
            "cuda:" + str(properties.get("gpu_id"))
            if torch.cuda.is_available()
            else "cpu"
        )
        self.load_model(model_dir)
        self.model.eval()
        self.initialized = True

    def preprocess(self, requests):
        # Note: each iteration overwrites the previous one, so only the last
        # request in the batch is processed; this handler effectively supports
        # a batch size of one.
        input_batch = {}
        for data in requests:
            body = data.get("body")
            input_batch["input_text"] = body.get("text")
            input_batch["num_samples"] = body.get("num_samples")
            input_batch["length"] = body.get("length")
        del requests
        gc.collect()
        return input_batch

    def inference(self, input_batch):
        input_text = input_batch["input_text"]
        length = input_batch["length"]
        num_samples = input_batch["num_samples"]

        input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to(
            self.device
        )

        # Override the config defaults with the per-request generation settings.
        self.task_config["max_length"] = length
        self.task_config["num_return_sequences"] = num_samples

        inference_output = self.model.generate(input_ids, **self.task_config)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        del input_batch
        gc.collect()
        return inference_output

    def postprocess(self, inference_output):
        output = self.tokenizer.batch_decode(
            inference_output.tolist(), skip_special_tokens=True
        )
        del inference_output
        gc.collect()
        return [json.dumps(output, ensure_ascii=False)]

    def handle(self, data, context):
        self.context = context
        data = self.preprocess(data)
        data = self.inference(data)
        data = self.postprocess(data)
        return data
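
# Illustrative request body for this handler (an assumption based on the
# accessors in preprocess() above, not a documented contract): the client is
# expected to POST JSON that TorchServe exposes under the "body" key, e.g.
#
#   {"text": "Once upon a time", "num_samples": 1, "length": 50}
#
# The response is a JSON-encoded list containing the decoded generated
# sequence(s).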