from compel import Compel, ReturnedEmbeddingsType import logging from abc import ABC import diffusers import torch from diffusers import StableDiffusionXLPipeline import numpy as np import threading import base64 from io import BytesIO from PIL import Image import numpy as np import uuid from tempfile import TemporaryFile from google.cloud import storage import sys from flask import Flask, request, jsonify logger = logging.getLogger(__name__) logger.info("Diffusers version %s", diffusers.__version__) class DiffusersHandler(ABC): """ Diffusers handler class for text to image generation. """ def __init__(self): self.initialized = False def initialize(self, properties): """In this initialize function, the Stable Diffusion model is loaded and initialized here. Args: ctx (context): It is a JSON Object containing information pertaining to the model artefacts parameters. """ logger.info("Loading diffusion model") logger.info("I'm totally new and updated") device_str = "cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() and properties.get("gpu_id") is not None else "cpu" print("my device is " + device_str) self.device = torch.device(device_str) self.pipe = StableDiffusionXLPipeline.from_pretrained( sys.argv[1], torch_dtype=torch.float16, use_safetensors=True, ) logger.info("moving model to device: %s", device_str) self.pipe.to(self.device) logger.info(self.device) logger.info("Diffusion model from path %s loaded successfully") self.initialized = True def preprocess(self, raw_requests): """Basic text preprocessing, of the user's prompt. Args: requests (str): The Input data in the form of text is passed on to the preprocess function. Returns: list : The preprocess function returns a list of prompts. """ logger.info("Received requests: '%s'", raw_requests) self.working = True processed_request = { "prompt": raw_requests[0]["prompt"], "negative_prompt": raw_requests[0].get("negative_prompt"), "width": raw_requests[0].get("width"), "height": raw_requests[0].get("height"), "num_inference_steps": raw_requests[0].get("num_inference_steps", 30), "guidance_scale": raw_requests[0].get("guidance_scale", 7.5), } logger.info("Processed request: '%s'", processed_request) return processed_request def inference(self, request): """Generates the image relevant to the received text. Args: inputs (list): List of Text from the pre-process function is passed here Returns: list : It returns a list of the generate images for the input text """ # Handling inference for sequence_classification. compel = Compel(tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2] , text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2], returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, requires_pooled=[False, True]) self.prompt = request.pop("prompt") conditioning, pooled = compel(self.prompt) # Handling inference for sequence_classification. inferences = self.pipe( prompt_embeds=conditioning, pooled_prompt_embeds=pooled, **request ).images logger.info("Generated image: '%s'", inferences) return inferences def postprocess(self, inference_outputs): """Post Process Function converts the generated image into Torchserve readable format. Args: inference_outputs (list): It contains the generated image of the input text. Returns: (list): Returns a list of the images. """ bucket_name = "outputs-storage-prod" client = storage.Client() self.working = False bucket = client.get_bucket(bucket_name) outputs = [] for image in inference_outputs: image_name = str(uuid.uuid4()) blob = bucket.blob(image_name + '.png') with TemporaryFile() as tmp: image.save(tmp, format="png") tmp.seek(0) blob.upload_from_file(tmp, content_type='image/png') # generate txt file with the image name and the prompt inside # blob = bucket.blob(image_name + '.txt') # blob.upload_from_string(self.prompt) outputs.append('https://storage.googleapis.com/' + bucket_name + '/' + image_name + '.png') return outputs app = Flask(__name__) # Initialize the handler on startup gpu_count = torch.cuda.device_count() if gpu_count == 0: raise ValueError("No GPUs available!") handlers = [DiffusersHandler() for i in range(gpu_count)] for i in range(gpu_count): handlers[i].initialize({"gpu_id": i}) handler_lock = threading.Lock() handler_index = 0 @app.route('/generate', methods=['POST']) def generate_image(): global handler_index try: # Extract raw requests from HTTP POST body raw_requests = request.json with handler_lock: selected_handler = handlers[handler_index] handler_index = (handler_index + 1) % gpu_count # Rotate to the next handler processed_request = selected_handler.preprocess([raw_requests]) inferences = selected_handler.inference(processed_request) outputs = selected_handler.postprocess(inferences) return jsonify({"image_urls": outputs}) except Exception as e: logger.error("Error during image generation: %s", str(e)) return jsonify({"error": "Failed to generate image", "details": str(e)}), 500 if __name__ == '__main__': app.run(host='0.0.0.0', port=3000, threaded=True)