import logging
import uuid
from abc import ABC
from tempfile import TemporaryFile

import diffusers
import torch
from diffusers import StableDiffusionXLPipeline
from google.cloud import storage
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)
logger.info("Diffusers version %s", diffusers.__version__)


class DiffusersHandler(BaseHandler, ABC):
    """
    Diffusers handler class for text-to-image generation.
    """

    def __init__(self):
        self.initialized = False

    def initialize(self, ctx):
        """Load and initialize the Stable Diffusion XL model.

        Args:
            ctx (context): JSON object containing information pertaining to
                the model artifact parameters.
        """
        logger.info("Loading diffusion model")

        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")

        # Fall back to CPU when no GPU is assigned to this worker.
        device_str = (
            "cuda:" + str(properties.get("gpu_id"))
            if torch.cuda.is_available() and properties.get("gpu_id") is not None
            else "cpu"
        )
        self.device = torch.device(device_str)

        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            "./",
            torch_dtype=torch.float16,
            use_safetensors=True,
        )
        logger.info("Moving model to device: %s", device_str)
        self.pipe.to(self.device)

        logger.info("Diffusion model from path %s loaded successfully", model_dir)
        self.initialized = True

    def preprocess(self, raw_requests):
        """Basic preprocessing of the user's prompt.

        Args:
            raw_requests (list): List of request dicts; only the first
                request in the batch is processed.

        Returns:
            dict: Keyword arguments to pass to the diffusion pipeline.
        """
        logger.info("Received requests: '%s'", raw_requests)
        processed_request = {
            "prompt": raw_requests[0]["prompt"],
            "negative_prompt": raw_requests[0].get("negative_prompt"),
            "width": raw_requests[0].get("width"),
            "height": raw_requests[0].get("height"),
            "num_inference_steps": raw_requests[0].get("num_inference_steps", 30),
            "guidance_scale": raw_requests[0].get("guidance_scale", 7.5),
        }
        logger.info("Processed request: '%s'", processed_request)
        return processed_request

    def inference(self, request):
        """Generate the images for the received prompt.

        Args:
            request (dict): Pipeline keyword arguments produced by preprocess.

        Returns:
            list: The generated PIL images.
        """
        inferences = self.pipe(**request).images
        logger.info("Generated image: '%s'", inferences)
        return inferences

    def postprocess(self, inference_outputs):
        """Upload the generated images to Cloud Storage and return their URLs.

        Args:
            inference_outputs (list): The generated images.

        Returns:
            list: Public URLs of the uploaded images.
        """
        bucket_name = "outputs-storage-prod"
        client = storage.Client()
        bucket = client.get_bucket(bucket_name)

        outputs = []
        for image in inference_outputs:
            image_name = str(uuid.uuid4())
            blob = bucket.blob(image_name + ".png")
            with TemporaryFile() as tmp:
                image.save(tmp, format="PNG")
                tmp.seek(0)
                blob.upload_from_file(tmp, content_type="image/png")
            outputs.append(
                "https://storage.googleapis.com/"
                + bucket_name
                + "/"
                + image_name
                + ".png"
            )
        return outputs
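
# ---------------------------------------------------------------------------
# Example client call: a minimal sketch of how to exercise this handler once
# the model archive is registered with TorchServe. The model name "sdxl" and
# the default inference port 8080 are assumptions, not values defined in this
# file; adjust them to match your deployment. The guard keeps this from
# running when TorchServe imports the module.
if __name__ == "__main__":
    import json
    from urllib import request as urllib_request

    payload = {
        "prompt": "an astronaut riding a horse on the moon",
        "num_inference_steps": 30,
        "guidance_scale": 7.5,
    }
    req = urllib_request.Request(
        "http://localhost:8080/predictions/sdxl",  # assumed model name
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib_request.urlopen(req) as resp:
        # The response body is the Cloud Storage URL(s) from postprocess().
        print(resp.read().decode("utf-8"))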