import gc
import logging

import torch
from transformers import GPT2LMHeadModel
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class SampleTransformerModel(BaseHandler):
    """Custom TorchServe handler that serves a GPT-2 language model."""

    def __init__(self):
        super().__init__()
        self.model = None
        self.device = None
        self.initialized = False

    def load_model(self, model_dir):
        self.model = GPT2LMHeadModel.from_pretrained(model_dir, return_dict=True)
        self.model.to(self.device)

    def initialize(self, ctx):
        # self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        # Pin the worker to the GPU TorchServe assigned to it, if CUDA is available.
        self.device = torch.device(
            "cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu"
        )
        self.load_model(model_dir)
        self.model.eval()
        self.initialized = True

    def preprocess(self, requests):
        # Each request body must carry "text" (a list of token ids), "num_samples"
        # (how many continuations to sample), and "length" (new tokens to generate
        # beyond the prompt). Note that the keys are overwritten on every loop
        # iteration, so only the last request survives: this handler effectively
        # supports a batch size of 1.
        input_batch = {}
        for idx, data in enumerate(requests):
            input_ids = torch.tensor([data.get("body").get("text")]).to(self.device)
            input_batch["input_ids"] = input_ids
            input_batch["num_samples"] = data.get("body").get("num_samples")
            input_batch["length"] = data.get("body").get("length") + len(data.get("body").get("text"))
        # Drop the request payloads eagerly so their memory can be reclaimed.
        del requests
        gc.collect()
        return input_batch

    def inference(self, input_batch):
        input_ids = input_batch["input_ids"]
        length = input_batch["length"]
        # Sample continuations with combined top-k and nucleus (top-p) filtering.
        inference_output = self.model.generate(
            input_ids,
            bos_token_id=self.model.config.bos_token_id,
            eos_token_id=self.model.config.eos_token_id,
            pad_token_id=self.model.config.eos_token_id,
            do_sample=True,
            max_length=length,
            top_k=50,
            top_p=0.95,
            no_repeat_ngram_size=2,
            num_return_sequences=input_batch["num_samples"],
        )
        # Release cached GPU blocks and the input tensors between requests.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        del input_batch
        gc.collect()
        return inference_output

    def postprocess(self, inference_output):
        # Move the generated token ids to the CPU and return them as plain lists.
        output = inference_output.cpu().numpy().tolist()
        del inference_output
        gc.collect()
        return [output]

    def handle(self, data, context):
        # self.context = context
        data = self.preprocess(data)
        data = self.inference(data)
        data = self.postprocess(data)
        return data
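

# ---------------------------------------------------------------------------
# A minimal local smoke test, included as an illustrative sketch. The
# MockContext class, the "gpt2" checkpoint name, and the sample request values
# are assumptions for this example; TorchServe supplies the real context
# object in production, and only ctx.system_properties is consulted by
# initialize() above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import GPT2Tokenizer

    class MockContext:
        """Hypothetical stand-in for the TorchServe context object."""

        def __init__(self, model_dir):
            self.system_properties = {"model_dir": model_dir, "gpu_id": 0}

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    handler = SampleTransformerModel()
    handler.initialize(MockContext("gpt2"))

    # The handler expects "text" to be a list of token ids, not a raw string.
    request = [{
        "body": {
            "text": tokenizer.encode("The meaning of life is"),
            "num_samples": 2,  # sampled continuations to return
            "length": 20,      # new tokens to generate beyond the prompt
        }
    }]

    # handle() returns a single-element list wrapping num_samples sequences.
    for sequence in handler.handle(request, MockContext("gpt2"))[0]:
        print(tokenizer.decode(sequence))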