# conroy-test / handler.py
import base64
from io import BytesIO
from typing import Any, Dict

import torch
from PIL import Image
from diffusers import (
    AutoencoderKL,
    DPMSolverMultistepScheduler,
    StableDiffusionXLAdapterPipeline,
    StableDiffusionXLImg2ImgPipeline,
    T2IAdapter,
)
from controlnet_aux.pidi import PidiNetDetector
# set device; this handler requires a GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type != "cuda":
    raise ValueError("This handler needs to run on a GPU.")

class EndpointHandler:
    # Preload all the elements you are going to need at inference.
    def __init__(self, path: str = ""):
        # load the T2I sketch adapter
        adapter = T2IAdapter.from_pretrained(
            "Adapter/t2iadapter",
            subfolder="sketch_sdxl_1.0",
            torch_dtype=torch.float16,
            adapter_type="full_adapter_xl",
            use_safetensors=True,
        )
        # load the variational autoencoder (VAE); this variant is patched to
        # avoid NaNs when run in fp16
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
            use_safetensors=True,
        )
        # load the scheduler
        scheduler = DPMSolverMultistepScheduler.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            subfolder="scheduler",
            use_lu_lambdas=True,
            euler_at_final=True,
        )
        # instantiate the HF pipeline that combines all the components
        self.pipeline = StableDiffusionXLAdapterPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            adapter=adapter,
            vae=vae,
            scheduler=scheduler,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        )
        # instantiate the HF refiner to improve the output image; it shares
        # the second text encoder and the VAE with the base pipeline
        self.refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-refiner-1.0",
            text_encoder_2=self.pipeline.text_encoder_2,
            vae=vae,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        )
        # model CPU offload manages device placement itself, so the pipelines
        # are not moved to CUDA explicitly
        self.pipeline.enable_model_cpu_offload()
        self.refiner.enable_model_cpu_offload()
        # PidiNet edge detector used to turn the input image into a sketch
        self.pidinet = PidiNetDetector.from_pretrained("lllyasviel/Annotators").to("cuda")
    def __call__(self, data: Dict[str, Any]) -> Image.Image:
        """
        data args:
            inputs (`str`): text prompt describing the image to generate.
            image (`str`): base64-encoded conditioning image.
            adapter_conditioning_scale (`float`, optional): defaults to 1.0.
            adapter_conditioning_factor (`float`, optional): defaults to 1.0.
        Return:
            A `PIL.Image.Image` that will be serialized and returned by the endpoint.
        """
        # get inputs
        inputs = data.pop("inputs", "")
        encoded_image = data.pop("image", None)
        if encoded_image is None:
            raise ValueError("Request must include a base64-encoded 'image' field.")
        adapter_conditioning_scale = data.pop("adapter_conditioning_scale", 1.0)
        adapter_conditioning_factor = data.pop("adapter_conditioning_factor", 1.0)
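        # For reference, a request body is expected to look roughly like the
        # sketch below (the concrete values are illustrative, not from the repo):
        #
        #   {
        #       "inputs": "an aerial photo of a lakeside cabin",
        #       "image": "<base64-encoded PNG or JPEG bytes>",
        #       "adapter_conditioning_scale": 0.9,
        #       "adapter_conditioning_factor": 1.0
        #   }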
        # decode the image and run PidiNet edge detection to get a
        # black-and-white sketch for the adapter to condition on
        decoded_image = self.decode_base64_image(encoded_image).convert("RGB")
        sketch_image = self.pidinet(
            decoded_image,
            detect_resolution=1024,
            image_resolution=1024,
            apply_filter=True,
        ).convert("L")
        # sketch_image.save("./output1.png")  # uncomment to inspect the sketch

        num_inference_steps = 25
        # fraction of the denoising schedule handled by the base model; the
        # refiner takes over from the same point (an ensemble-of-experts split)
        high_noise_frac = 0.7
        # base pass: run the adapter pipeline over the high-noise steps and
        # hand the latents to the refiner
        base_image = self.pipeline(
            prompt=inputs,
            negative_prompt="extra digit, fewer digits, cropped, worst quality, low quality",
            image=sketch_image,
            num_inference_steps=num_inference_steps,
            denoising_end=high_noise_frac,
            guidance_scale=7.5,
            adapter_conditioning_scale=adapter_conditioning_scale,
            adapter_conditioning_factor=adapter_conditioning_factor,
            output_type="latent",
        ).images
        # refiner pass: finish the low-noise steps; the img2img refiner does
        # not accept the adapter_conditioning_* kwargs, so they are not passed
        output_image = self.refiner(
            prompt=inputs,
            negative_prompt="extra digit, fewer digits, cropped, worst quality, low quality",
            image=base_image,
            num_inference_steps=num_inference_steps,
            denoising_start=high_noise_frac,
            guidance_scale=7.5,
        ).images[0]
        # output_image.save("./output2.png")  # uncomment to inspect the result
        return output_image
    # helper to decode a base64-encoded input image
    def decode_base64_image(self, image_string: str) -> Image.Image:
        base64_image = base64.b64decode(image_string)
        buffer = BytesIO(base64_image)
        image = Image.open(buffer)
        return image
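

# Minimal local smoke test, assuming a CUDA machine with the models above
# available; in production the Inference Endpoints toolkit instantiates and
# calls EndpointHandler instead. The file name "sketch.png" is hypothetical.
if __name__ == "__main__":
    handler = EndpointHandler()
    with open("sketch.png", "rb") as f:
        payload = {
            "inputs": "a cozy cabin in the woods at golden hour",
            "image": base64.b64encode(f.read()).decode("utf-8"),
            "adapter_conditioning_scale": 0.9,
        }
    result = handler(payload)
    result.save("result.png")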