import base64
from io import BytesIO
from typing import Any, Dict

import requests
import torch
from PIL import Image
from diffusers import (
    DDIMScheduler,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionPipeline,
)
from transformers.utils import logging

logging.set_verbosity_info()
logger = logging.get_logger("transformers")


def load_image(image_url):
    """Load a PIL image from either a base64 data URI or a standard URL."""
    if image_url.startswith("data:"):
        # Decode base64 data URI
        image_data = base64.b64decode(image_url.split(",")[1])
        image = Image.open(BytesIO(image_data))
    else:
        # Fetch the image from a standard URL
        response = requests.get(image_url)
        image = Image.open(BytesIO(response.content))
    return image


# Set device; the fp16 pipelines below require a GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type != "cuda":
    raise ValueError("need to run on GPU")

model_id = "stabilityai/stable-diffusion-2-1-base"


class EndpointHandler:
    def __init__(self, path=""):
        # Load the text-to-image pipeline (not used in __call__ below,
        # but kept loaded and ready for text-only requests)
        self.textPipe = StableDiffusionPipeline.from_pretrained(
            model_id, torch_dtype=torch.float16
        )
        self.textPipe.scheduler = DDIMScheduler.from_config(self.textPipe.scheduler.config)
        self.textPipe = self.textPipe.to(device)

        # Create the img2img pipeline used for inference
        self.imgPipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            model_id, torch_dtype=torch.float16
        )
        self.imgPipe.scheduler = DDIMScheduler.from_config(self.imgPipe.scheduler.config)
        self.imgPipe = self.imgPipe.to(device)

    def __call__(self, data: Dict[str, Any]) -> Image.Image:
        """
        Args:
            data (:obj:`dict`): includes the input prompt, the init image URL,
                and the parameters for the inference.
        Return:
            A :obj:`PIL.Image.Image`: the first generated image.
        """
        prompt = data.pop("inputs", data)
        url = data.pop("url", None)
        init_image = load_image(url).convert("RGB")
        # Downscale the init image in place to fit within 512x512
        init_image.thumbnail((512, 512))

        # Fall back to the top-level payload if no "parameters" key is present
        params = data.pop("parameters", data)

        # Hyperparameters
        num_inference_steps = params.pop("num_inference_steps", 25)
        guidance_scale = params.pop("guidance_scale", 7.5)
        negative_prompt = params.pop("negative_prompt", None)
        strength = params.pop("strength", 0.8)
        # height/width are accepted but not forwarded: img2img derives the
        # output size from the init image
        height = params.pop("height", None)
        width = params.pop("width", None)
        manual_seed = params.pop("manual_seed", -1)
        logger.info(
            f"strength: {strength}, manual_seed: {manual_seed}, "
            f"inference_steps: {num_inference_steps}, guidance_scale: {guidance_scale}"
        )

        generator = torch.Generator(device=device)
        generator.manual_seed(manual_seed)

        # Run the img2img pipeline, passing the seeded generator so results
        # are reproducible for a given manual_seed
        out = self.imgPipe(
            prompt,
            image=init_image,
            strength=strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            negative_prompt=negative_prompt,
            generator=generator,
            # height=height,
            # width=width,
        )

        # Return the first generated PIL image
        return out.images[0]
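

# --- Usage sketch (assumption): a minimal local smoke test, not part of the
# handler contract. The payload shape mirrors what the Hugging Face Inference
# Endpoints toolkit passes to __call__; the prompt, image URL, and parameter
# values below are hypothetical. Requires a CUDA GPU.
if __name__ == "__main__":
    handler = EndpointHandler()
    image = handler(
        {
            "inputs": "a watercolor painting of a lighthouse at dawn",
            "url": "https://example.com/init.png",  # hypothetical init image
            "parameters": {
                "strength": 0.6,
                "num_inference_steps": 30,
                "manual_seed": 42,
            },
        }
    )
    image.save("out.png")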