from typing import Any, Dict, List, Union
import base64
from io import BytesIO

import torch
from PIL import Image
# AutoencoderKL is only needed if the optional fp16-fix VAE below is enabled
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL

import controlnet_hinter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type != "cuda":
    raise ValueError("This handler needs to run on a GPU.")
# use bfloat16 on Ampere (compute capability 8.x) and newer, float16 otherwise
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16


SDXL_CACHE = "./sdxl-cache"
CN_CACHE = "./cn-canny-edge-cache"

# For the moment only canny_edge is backed by an SDXL ControlNet (the local
# cache); the remaining entries still point at SD 1.5 ControlNets and are not
# compatible with StableDiffusionXLControlNetPipeline.
SDXLCONTROLNET_MAPPING = {
    "canny_edge": {
        "model_id": CN_CACHE,
        "hinter": controlnet_hinter.hint_canny
    },
    "pose": {
        "model_id": "lllyasviel/sd-controlnet-openpose",
        "hinter": controlnet_hinter.hint_openpose
    },
    "depth": {
        "model_id": "lllyasviel/sd-controlnet-depth",
        "hinter": controlnet_hinter.hint_depth
    },
    "scribble": {
        "model_id": "lllyasviel/sd-controlnet-scribble",
        "hinter": controlnet_hinter.hint_scribble,
    },
    "segmentation": {
        "model_id": "lllyasviel/sd-controlnet-seg",
        "hinter": controlnet_hinter.hint_segmentation,
    },
    "normal": {
        "model_id": "lllyasviel/sd-controlnet-normal",
        "hinter": controlnet_hinter.hint_normal,
    },
    "hed": {
        "model_id": "lllyasviel/sd-controlnet-hed",
        "hinter": controlnet_hinter.hint_hed,
    },
    "hough": {
        "model_id": "lllyasviel/sd-controlnet-mlsd",
        "hinter": controlnet_hinter.hint_hough,
    }
}


class EndpointHandler:
    def __init__(self, path=""):
        # define the default controlnet type and load its controlnet
        self.control_type = "canny_edge"
        self.controlnet = ControlNetModel.from_pretrained(
            SDXLCONTROLNET_MAPPING[self.control_type]["model_id"], torch_dtype=dtype).to(device)

        # load the SDXL ControlNet pipeline from the local cache; note that
        # SDXL pipelines do not take a safety_checker argument
        self.sdxl_id = SDXL_CACHE
        self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
            self.sdxl_id,
            controlnet=self.controlnet,
            # optionally swap in the fp16-safe VAE:
            # vae=AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype, use_safetensors=True),
            torch_dtype=dtype,
        ).to(device)
        # fixed CPU seed so identical requests produce identical images
        self.generator = torch.Generator(device="cpu").manual_seed(3)

    def __call__(self, data: Any) -> Union[List[Image.Image], Dict[str, str]]:
        """
        :param data: A dictionary containing an `inputs` prompt, a base64-encoded
            `image`, and optional generation parameters.
        :return: A list of generated `PIL.Image` objects, or a dictionary with an
            `error` field if required inputs are missing.
        """
        prompt = data.pop("inputs", None)
        image = data.pop("image", None)
        num_of_images = data.pop("num_of_images", 1)
        controlnet_type = data.pop("controlnet_type", None)

        # both a prompt and an image are required: the image is decoded
        # unconditionally below, so reject the request if either is missing
        if prompt is None or image is None:
            return {"error": "Please provide a prompt and a base64 encoded image."}

        # Check if a different controlnet type is requested
        if controlnet_type is not None and controlnet_type != self.control_type:
            if controlnet_type not in SDXLCONTROLNET_MAPPING:
                return {"error": f"Unknown controlnet_type '{controlnet_type}'."}
            print(
                f"changing controlnet from {self.control_type} to {controlnet_type} "
                f"using {SDXLCONTROLNET_MAPPING[controlnet_type]['model_id']} model")
            self.control_type = controlnet_type
            self.controlnet = ControlNetModel.from_pretrained(
                SDXLCONTROLNET_MAPPING[self.control_type]["model_id"], torch_dtype=dtype).to(device)
            self.pipe.controlnet = self.controlnet

        # hyperparameters
        num_inference_steps = data.pop("num_inference_steps", 30)
        guidance_scale = data.pop("guidance_scale", 7.5)
        negative_prompt = data.pop("negative_prompt", None)
        height = data.pop("height", 1024)
        width = data.pop("width", 1024)
        controlnet_conditioning_scale = data.pop(
            "controlnet_conditioning_scale", 1.0)

        # decode the input image and compute the control hint at the target size
        image = self.decode_base64_image(image)
        control_image = SDXLCONTROLNET_MAPPING[self.control_type]["hinter"](
            image, width=width, height=height)

        # run inference pipeline
        out = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=control_image,
            width=width,
            height=height,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_of_images,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            generator=self.generator
        )

        return out.images

    # helper to decode a base64-encoded input image
    def decode_base64_image(self, image_string):
        base64_image = base64.b64decode(image_string)
        buffer = BytesIO(base64_image)
        # normalize to RGB so RGBA or grayscale uploads don't trip the hinter
        image = Image.open(buffer).convert("RGB")
        return image
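

# Minimal local smoke test (a sketch, not part of the endpoint contract): it
# assumes a GPU, the model caches above, and a readable "input.png" on disk.
if __name__ == "__main__":
    with open("input.png", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler()
    result = handler({
        "inputs": "a futuristic city at sunset, highly detailed",
        "image": image_b64,
        "num_inference_steps": 30,
        "guidance_scale": 7.5,
    })
    if isinstance(result, dict):  # error path returns a dict with an `error` key
        print(result["error"])
    else:
        for i, img in enumerate(result):
            img.save(f"output_{i}.png")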