File size: 6,470 Bytes
fae0531 bec7879 fae0531 bec7879 d4747d7 fae0531 8aaec62 fae0531 8aaec62 fae0531 b9fbc24 d4747d7 b9fbc24 fae0531 bec7879 b9fbc24 fae0531 d4747d7 fae0531 8aaec62 a1f58f2 fae0531 a1f58f2 d4747d7 bec7879 fae0531 bec7879 fae0531 bec7879 fae0531 4cc0dca fae0531 bec7879 fae0531 bec7879 d4747d7 bec7879 fae0531 bec7879 b9fbc24 fae0531 466101e fae0531 bec7879 fae0531 bec7879 662553c fae0531 bec7879 b9fbc24 fae0531 bec7879 4cc0dca 466101e bec7879 662553c fae0531 662553c bec7879 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
from typing import Dict, List, Any
import base64
from PIL import Image
from io import BytesIO
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL, UniPCMultistepScheduler
import torch
import numpy as np
import cv2
import controlnet_hinter
# Inference needs a CUDA GPU; fail fast at import time otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type != 'cuda':
    raise ValueError("need to run on GPU")

# Mixed-precision dtype: bfloat16 on GPUs with compute capability 8.0 or
# newer, float16 on older ones. `>= 8` (not `== 8`) so that architectures
# with a major version of 9 and above also get bfloat16.
dtype = (
    torch.bfloat16
    if torch.cuda.get_device_capability()[0] >= 8
    else torch.float16
)

# Local directories passed to `from_pretrained`: the SDXL base pipeline and
# the checkpoint used for the "canny_edge" ControlNet.
SDXL_CACHE = "./sdxl-cache"
CN_CACHE = "./cn-canny-edge-cache"
# ControlNet registry consulted by EndpointHandler: maps a control type to
# the checkpoint to load ("model_id") and the function that turns an input
# image into a control image ("hinter").
# For the moment only "canny_edge" is supported; it loads the local CN_CACHE
# checkpoint.
# NOTE(review): the other entries point at lllyasviel/sd-controlnet-*
# checkpoints, which by name look like SD 1.5 ControlNets and are probably
# not usable with the SDXL pipeline. Confirm before relying on them.
SDXLCONTROLNET_MAPPING = {
    control_type: {"model_id": checkpoint, "hinter": hint_fn}
    for control_type, checkpoint, hint_fn in (
        ("canny_edge", CN_CACHE, controlnet_hinter.hint_canny),
        ("pose", "lllyasviel/sd-controlnet-openpose",
         controlnet_hinter.hint_openpose),
        ("depth", "lllyasviel/sd-controlnet-depth",
         controlnet_hinter.hint_depth),
        ("scribble", "lllyasviel/sd-controlnet-scribble",
         controlnet_hinter.hint_scribble),
        ("segmentation", "lllyasviel/sd-controlnet-seg",
         controlnet_hinter.hint_segmentation),
        ("normal", "lllyasviel/sd-controlnet-normal",
         controlnet_hinter.hint_normal),
        ("hed", "lllyasviel/sd-controlnet-hed",
         controlnet_hinter.hint_hed),
        ("hough", "lllyasviel/sd-controlnet-mlsd",
         controlnet_hinter.hint_hough),
    )
}
# Control type -> {"model_id": Hub checkpoint id, "hinter": function that
# builds the control image}. Every entry uses an lllyasviel/sd-controlnet-*
# checkpoint. EndpointHandler does not read this table; it uses
# SDXLCONTROLNET_MAPPING instead.
CONTROLNET_MAPPING = {
    control_type: {"model_id": checkpoint, "hinter": hint_fn}
    for control_type, checkpoint, hint_fn in (
        ("canny_edge", "lllyasviel/sd-controlnet-canny",
         controlnet_hinter.hint_canny),
        ("pose", "lllyasviel/sd-controlnet-openpose",
         controlnet_hinter.hint_openpose),
        ("depth", "lllyasviel/sd-controlnet-depth",
         controlnet_hinter.hint_depth),
        ("scribble", "lllyasviel/sd-controlnet-scribble",
         controlnet_hinter.hint_scribble),
        ("segmentation", "lllyasviel/sd-controlnet-seg",
         controlnet_hinter.hint_segmentation),
        ("normal", "lllyasviel/sd-controlnet-normal",
         controlnet_hinter.hint_normal),
        ("hed", "lllyasviel/sd-controlnet-hed",
         controlnet_hinter.hint_hed),
        ("hough", "lllyasviel/sd-controlnet-mlsd",
         controlnet_hinter.hint_hough),
    )
}
class EndpointHandler():
    """Inference-endpoint handler for SDXL + ControlNet image generation.

    The pipeline is loaded once in ``__init__``. Each call builds a control
    image from the request's input image and runs the pipeline on it.
    """

    def __init__(self, path=""):
        """Load the default ControlNet and the SDXL ControlNet pipeline.

        :param path: Not used. Model locations come from the module-level
            ``SDXL_CACHE`` and ``SDXLCONTROLNET_MAPPING`` settings.
        """
        # Default control type; it can be switched per request.
        self.control_type = "canny_edge"
        self.controlnet = self._load_controlnet(self.control_type)

        self.sdxl_id = SDXL_CACHE
        # NOTE: an alternative VAE (e.g. AutoencoderKL.from_pretrained(
        # "madebyollin/sdxl-vae-fp16-fix", ...)) could be passed as `vae=`.
        self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
            self.sdxl_id,
            controlnet=self.controlnet,
            torch_dtype=dtype,
            safety_checker=None,
        ).to(device)
        # One generator seeded once and reused across calls. The sequence of
        # outputs is reproducible from process start, but successive calls
        # with the same request produce different images.
        self.generator = torch.Generator(device="cpu").manual_seed(3)

    def _load_controlnet(self, control_type):
        """Load the ControlNet registered for *control_type* onto ``device``."""
        return ControlNetModel.from_pretrained(
            SDXLCONTROLNET_MAPPING[control_type]["model_id"],
            torch_dtype=dtype,
        ).to(device)

    def __call__(self, data: Any) -> Any:
        """Generate images from a prompt, conditioned on a control image.

        :param data: Request dict; the keys listed below are popped from it,
            so it is mutated.

            Required:

            - ``inputs``: the text prompt.
            - ``image``: a base64-encoded input image.

            Optional:

            - ``num_of_images``: default 1.
            - ``controlnet_type``: a key of ``SDXLCONTROLNET_MAPPING``.
            - ``num_inference_steps``: default 30.
            - ``guidance_scale``: default 7.5.
            - ``negative_prompt``: default None.
            - ``height`` / ``width``: default 1024; the size the control
              image is built at.
            - ``controlnet_conditioning_scale``: default 1.0.
        :return: The pipeline's ``images`` list on success, or a dict with
            an ``error`` key when the request is invalid.
        """
        prompt = data.pop("inputs", None)
        image = data.pop("image", None)
        num_of_images = data.pop("num_of_images", None)
        controlnet_type = data.pop("controlnet_type", None)

        # Both fields are required. The image is decoded unconditionally
        # below, and the prompt drives generation, so reject the request if
        # either one is missing.
        if prompt is None or image is None:
            return {"error": "Please provide a prompt and base64 encoded image."}
        if num_of_images is None:
            num_of_images = 1

        # Swap the ControlNet only when a different type is requested, and
        # reject unknown types instead of failing with a KeyError.
        if controlnet_type is not None and controlnet_type != self.control_type:
            if controlnet_type not in SDXLCONTROLNET_MAPPING:
                supported = ", ".join(SDXLCONTROLNET_MAPPING)
                return {
                    "error": f"Unsupported controlnet_type '{controlnet_type}'. "
                             f"Supported types: {supported}."
                }
            print(
                f"changing controlnet from {self.control_type} to {controlnet_type} using {SDXLCONTROLNET_MAPPING[controlnet_type]['model_id']} model")
            self.control_type = controlnet_type
            self.controlnet = self._load_controlnet(self.control_type)
            self.pipe.controlnet = self.controlnet

        # hyperparameters
        num_inference_steps = data.pop("num_inference_steps", 30)
        guidance_scale = data.pop("guidance_scale", 7.5)
        negative_prompt = data.pop("negative_prompt", None)
        height = data.pop("height", 1024)
        width = data.pop("width", 1024)
        controlnet_conditioning_scale = data.pop(
            "controlnet_conditioning_scale", 1.0)

        # Build the control image at the requested size. width/height are not
        # passed to the pipeline itself; presumably it takes the output size
        # from the control image. TODO confirm.
        image = self.decode_base64_image(image)
        control_image = SDXLCONTROLNET_MAPPING[self.control_type]["hinter"](
            image, width=width, height=height)

        out = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=control_image,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_of_images,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            generator=self.generator,
        )
        return out.images

    def decode_base64_image(self, image_string):
        """Decode a base64 string and open the resulting bytes as a PIL image.

        Malformed base64 input raises ``binascii.Error`` from ``b64decode``.
        """
        return Image.open(BytesIO(base64.b64decode(image_string)))
|