from typing import Any, Dict

import base64
from io import BytesIO

import torch
from PIL import Image
from diffusers import (
    T2IAdapter,
    StableDiffusionXLAdapterPipeline,
    StableDiffusionXLImg2ImgPipeline,
    AutoencoderKL,
    DPMSolverMultistepScheduler,
)
from controlnet_aux.pidi import PidiNetDetector

# Set device; this handler requires a CUDA-capable GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type != "cuda":
    raise RuntimeError("This handler must run on a CUDA-capable GPU.")


class EndpointHandler:
    def __init__(self, path: str = ""):
        """Preload every component needed at inference time."""
        # Load the T2I sketch adapter for SDXL.
        adapter = T2IAdapter.from_pretrained(
            "Adapter/t2iadapter",
            subfolder="sketch_sdxl_1.0",
            torch_dtype=torch.float16,
            adapter_type="full_adapter_xl",
            use_safetensors=True,
        )

        # Load a VAE patched to avoid fp16 overflow artifacts.
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
            use_safetensors=True,
        )

        # Load the scheduler.
        scheduler = DPMSolverMultistepScheduler.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            subfolder="scheduler",
            use_lu_lambdas=True,
            euler_at_final=True,
        )

        # Assemble the base pipeline from the components above.
        self.pipeline = StableDiffusionXLAdapterPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            adapter=adapter,
            vae=vae,
            scheduler=scheduler,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        )

        # Refiner that denoises the tail of the schedule to sharpen the output.
        # It shares the base pipeline's second text encoder and VAE.
        self.refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-refiner-1.0",
            text_encoder_2=self.pipeline.text_encoder_2,
            vae=vae,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        )

        # Offload submodules to CPU between forward passes to save VRAM.
        # Do not also call .to("cuda") on the pipelines: offloading manages
        # device placement itself, and moving the whole pipeline to the GPU
        # first defeats the purpose.
        self.pipeline.enable_model_cpu_offload()
        self.refiner.enable_model_cpu_offload()

        # Edge detector used to turn the input image into a sketch.
        self.pidinet = PidiNetDetector.from_pretrained("lllyasviel/Annotators").to("cuda")

    def __call__(self, data: Dict[str, Any]) -> Image.Image:
        """
        data args:
            inputs (str): text prompt
            image (str): base64-encoded conditioning image
            adapter_conditioning_scale (float, optional): default 1.0
            adapter_conditioning_factor (float, optional): default 1.0
        Return:
            The generated PIL.Image.
        """
        # Get inputs.
        inputs = data.pop("inputs", "")
        encoded_image = data.pop("image", None)
        adapter_conditioning_scale = data.pop("adapter_conditioning_scale", 1.0)
        adapter_conditioning_factor = data.pop("adapter_conditioning_factor", 1.0)

        # Decode the image and convert it to a black-and-white sketch.
        decoded_image = self.decode_base64_image(encoded_image).convert("RGB")
        sketch_image = self.pidinet(
            decoded_image,
            detect_resolution=1024,
            image_resolution=1024,
            apply_filter=True,
        ).convert("L")

        # Split the denoising schedule between base and refiner: the base
        # pipeline runs the first 70% of the steps and hands off latents,
        # the refiner finishes the remaining 30%.
        num_inference_steps = 25
        high_noise_frac = 0.7

        negative_prompt = "extra digit, fewer digits, cropped, worst quality, low quality"

        base_image = self.pipeline(
            prompt=inputs,
            negative_prompt=negative_prompt,
            image=sketch_image,
            num_inference_steps=num_inference_steps,
            denoising_end=high_noise_frac,
            guidance_scale=7.5,
            adapter_conditioning_scale=adapter_conditioning_scale,
            adapter_conditioning_factor=adapter_conditioning_factor,
            output_type="latent",
        ).images

        # The refiner is a plain img2img pipeline: it does not accept the
        # adapter_conditioning_* arguments, so they must not be passed here.
        output_image = self.refiner(
            prompt=inputs,
            negative_prompt=negative_prompt,
            image=base_image,
            num_inference_steps=num_inference_steps,
            denoising_start=high_noise_frac,
            guidance_scale=7.5,
        ).images[0]

        return output_image

    # Helper to decode the base64-encoded input image.
    def decode_base64_image(self, image_string: str) -> Image.Image:
        base64_image = base64.b64decode(image_string)
        buffer = BytesIO(base64_image)
        image = Image.open(buffer)
        return image
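

# --- Usage sketch (an assumption, not part of the deployed handler): calling
# the handler locally on a machine with a CUDA GPU. The payload keys mirror
# exactly what __call__ pops above; "sketch.png" and the prompt are
# hypothetical placeholders.
if __name__ == "__main__":
    with open("sketch.png", "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler()
    image = handler(
        {
            "inputs": "a stone castle on a hill at golden hour",
            "image": encoded,
            "adapter_conditioning_scale": 0.9,
            "adapter_conditioning_factor": 1.0,
        }
    )
    image.save("result.png")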