File size: 4,644 Bytes
aca26e9 aeba19d aca26e9 7010589 aca26e9 7010589 aca26e9 aeba19d aca26e9 e93d2b4 aca26e9 ff6535b aca26e9 30e8f60 9b92fa5 aca26e9 9b92fa5 aca26e9 9b92fa5 aca26e9 8330e27 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
from typing import Dict, List, Any
import torch
import base64
from PIL import Image
from io import BytesIO
from diffusers import T2IAdapter, StableDiffusionXLAdapterPipeline, StableDiffusionXLImg2ImgPipeline, AutoencoderKL, DPMSolverMultistepScheduler
from controlnet_aux.pidi import PidiNetDetector
# Pick the compute device. This module refuses to load without a CUDA GPU:
# importing it on a CPU-only machine raises ValueError immediately.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cpu":
    raise ValueError("need to run on GPU")
class EndpointHandler():
    """Sketch-conditioned SDXL image generation followed by a refiner pass.

    At construction time the handler loads a T2I sketch adapter, a VAE, a
    scheduler, the SDXL base pipeline, the SDXL refiner pipeline and a PidiNet
    edge detector. Each call turns the input image into a sketch with PidiNet,
    generates latents with the adapter-conditioned base pipeline, and finishes
    denoising with the refiner.
    """

    # Defaults applied when the request payload does not override them.
    DEFAULT_NEGATIVE_PROMPT = "extra digit, fewer digits, cropped, worst quality, low quality"
    DEFAULT_NUM_INFERENCE_STEPS = 25
    DEFAULT_HIGH_NOISE_FRAC = 0.7
    DEFAULT_GUIDANCE_SCALE = 7.5

    # Preload all the elements you are going to need at inference.
    def __init__(self, path=""):
        """Load every model component once so requests only run inference.

        Args:
            path: Accepted for interface compatibility; not used, since all
                weights are loaded from fixed model identifiers.
        """
        # load the T2I adapter
        adapter = T2IAdapter.from_pretrained(
            "Adapter/t2iadapter",
            subfolder="sketch_sdxl_1.0",
            torch_dtype=torch.float16,
            adapter_type="full_adapter_xl",
            use_safetensors=True,
        )

        # load variational autoencoder (VAE); shared by base and refiner
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
            use_safetensors=True,
        )

        # load the scheduler
        scheduler = DPMSolverMultistepScheduler.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            subfolder="scheduler",
            use_lu_lambdas=True,
            euler_at_final=True,
        )

        # instantiate HF pipeline to combine all the components
        self.pipeline = StableDiffusionXLAdapterPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            adapter=adapter,
            vae=vae,
            scheduler=scheduler,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        ).to("cuda")

        # instantiate HF refiner to improve output image; it reuses the base
        # pipeline's second text encoder and the same VAE instance
        self.refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-refiner-1.0",
            text_encoder_2=self.pipeline.text_encoder_2,
            vae=vae,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        ).to("cuda")

        # NOTE(review): both pipelines are moved to CUDA above and then have
        # model CPU offload enabled; the explicit .to("cuda") is presumably
        # redundant with offloading - confirm against the diffusers docs
        # before changing the load order.
        self.pipeline.enable_model_cpu_offload()
        self.refiner.enable_model_cpu_offload()

        # edge detector used to turn the input image into a sketch
        self.pidinet = PidiNetDetector.from_pretrained("lllyasviel/Annotators").to("cuda")

    def __call__(self, data: Dict[str, Any]) -> "Image.Image":
        """Generate an image from a text prompt and a base64-encoded image.

        The payload dict is read but not modified.

        Args:
            data: Request payload. Recognised keys:
                inputs: text prompt (default ``""``).
                image: base64-encoded input image (required); a leading
                    ``data:...;base64,`` header is accepted.
                adapter_conditioning_scale: passed to the base pipeline
                    (default 1.0).
                adapter_conditioning_factor: passed to the base pipeline
                    (default 1.0).
                negative_prompt: overrides ``DEFAULT_NEGATIVE_PROMPT``.
                num_inference_steps: overrides ``DEFAULT_NUM_INFERENCE_STEPS``.
                high_noise_frac: where the base pipeline stops and the
                    refiner starts; overrides ``DEFAULT_HIGH_NOISE_FRAC``.
                guidance_scale: overrides ``DEFAULT_GUIDANCE_SCALE``.

        Returns:
            The first image produced by the refiner pipeline.

        Raises:
            ValueError: if the payload has no ``image`` entry.
        """
        # get inputs
        prompt = data.get("inputs", "")
        encoded_image = data.get("image")
        # Fail early with a clear message instead of an opaque decode error.
        if not encoded_image:
            raise ValueError("request payload must include a base64-encoded 'image'")

        adapter_conditioning_scale = data.get("adapter_conditioning_scale", 1.0)
        adapter_conditioning_factor = data.get("adapter_conditioning_factor", 1.0)
        negative_prompt = data.get("negative_prompt", self.DEFAULT_NEGATIVE_PROMPT)
        num_inference_steps = data.get("num_inference_steps", self.DEFAULT_NUM_INFERENCE_STEPS)
        high_noise_frac = data.get("high_noise_frac", self.DEFAULT_HIGH_NOISE_FRAC)
        guidance_scale = data.get("guidance_scale", self.DEFAULT_GUIDANCE_SCALE)

        # Decode image and convert to black and white sketch
        decoded_image = self.decode_base64_image(encoded_image).convert('RGB')
        sketch_image = self.pidinet(
            decoded_image,
            detect_resolution=1024,
            image_resolution=1024,
            apply_filter=True,
        ).convert('L')

        # Base pass: adapter-conditioned generation, stopped at
        # high_noise_frac and returned as latents for the refiner.
        base_image = self.pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=sketch_image,
            num_inference_steps=num_inference_steps,
            denoising_end=high_noise_frac,
            guidance_scale=guidance_scale,
            adapter_conditioning_scale=adapter_conditioning_scale,
            adapter_conditioning_factor=adapter_conditioning_factor,
            output_type="latent",
        ).images

        # Refiner pass: resumes denoising at high_noise_frac. The refiner is
        # an img2img pipeline built without an adapter, so the
        # adapter_conditioning_* arguments are only given to the base pass.
        output_image = self.refiner(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=base_image,
            num_inference_steps=num_inference_steps,
            denoising_start=high_noise_frac,
            guidance_scale=guidance_scale,
        ).images[0]

        return output_image

    @staticmethod
    def _strip_data_uri_prefix(image_string):
        """Return the base64 payload of a ``data:`` URI, or the input unchanged.

        Only ``str`` values starting with ``data:`` and containing a comma are
        altered; everything else (plain base64 text, bytes) passes through.
        """
        if isinstance(image_string, str) and image_string.startswith("data:"):
            _header, separator, payload = image_string.partition(",")
            if separator:
                return payload
        return image_string

    # helper to decode input image
    def decode_base64_image(self, image_string):
        """Decode a base64 string (optionally a data URI) into a PIL image.

        Args:
            image_string: base64-encoded image data, with or without a
                ``data:...;base64,`` header.

        Returns:
            The image opened by ``PIL.Image.open`` from the decoded bytes.
        """
        # Without stripping, b64decode would silently drop the header's
        # punctuation and decode its letters as if they were image data.
        payload = self._strip_data_uri_prefix(image_string)
        raw_bytes = base64.b64decode(payload)
        return Image.open(BytesIO(raw_bytes))