from typing import Any, Dict
import base64
from PIL import Image
from io import BytesIO
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
# the safety checker lives in the stable_diffusion pipeline module, not the top-level package
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker

import torch
import controlnet_hinter

# set device: this handler requires a GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type != "cuda":
    raise ValueError("This handler needs to run on a GPU.")
# mixed precision dtype: bfloat16 on Ampere or newer (compute capability >= 8), float16 otherwise
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
# mapping from control type to the ControlNet model id and its matching hint function
CONTROLNET_MAPPING = {
    "canny_edge": {
        "model_id": "lllyasviel/sd-controlnet-canny",
        "hinter": controlnet_hinter.hint_canny
    },
    "pose": {
        "model_id": "lllyasviel/sd-controlnet-openpose",
        "hinter": controlnet_hinter.hint_openpose
    },
    "depth": {
        "model_id": "lllyasviel/sd-controlnet-depth",
        "hinter": controlnet_hinter.hint_depth
    },
    "scribble": {
        "model_id": "lllyasviel/sd-controlnet-scribble",
        "hinter": controlnet_hinter.hint_scribble,
    },
    "segmentation": {
        "model_id": "lllyasviel/sd-controlnet-seg",
        "hinter": controlnet_hinter.hint_segmentation,
    },
    "normal": {
        "model_id": "lllyasviel/sd-controlnet-normal",
        "hinter": controlnet_hinter.hint_normal,
    },
    "hed": {
        "model_id": "lllyasviel/sd-controlnet-hed",
        "hinter": controlnet_hinter.hint_hed,
    },
    "hough": {
        "model_id": "lllyasviel/sd-controlnet-mlsd",
        "hinter": controlnet_hinter.hint_hough,
    }
}
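
# Illustrative use of the mapping above (a sketch; `some_image` stands for any
# PIL RGB image): look up the hint function for a control type to turn an
# ordinary image into the control image its ControlNet expects, e.g.
#     control_image = CONTROLNET_MAPPING["canny_edge"]["hinter"](some_image)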


class EndpointHandler:
    def __init__(self, path: str = ""):
        # default control type; the matching ControlNet is loaded eagerly
        self.control_type = "depth"
        self.controlnet = ControlNetModel.from_pretrained(
            CONTROLNET_MAPPING[self.control_type]["model_id"], torch_dtype=dtype
        ).to(device)

        # load the Stable Diffusion ControlNet pipeline with an explicit safety checker
        # (alternative base model: "runwayml/stable-diffusion-v1-5")
        self.stable_diffusion_id = "Lykon/dreamshaper-8"
        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            self.stable_diffusion_id,
            controlnet=self.controlnet,
            torch_dtype=dtype,
            safety_checker=StableDiffusionSafetyChecker.from_pretrained(
                "CompVis/stable-diffusion-safety-checker", torch_dtype=dtype
            ),
        ).to(device)

        # fixed-seed generator so repeated requests are reproducible
        self.generator = torch.Generator(device=device.type).manual_seed(3)

    def __call__(self, data: Dict[str, Any]) -> Any:
        """
        :param data: A dictionary with an `inputs` (prompt) field, a base64-encoded
                     `image` field, and an optional `controlnet_type` field.
        :return: The first generated PIL image, or a dictionary with an `error` field.
        """
        prompt = data.pop("inputs", None)
        image = data.pop("image", None)
        controlnet_type = data.pop("controlnet_type", None)

        # both a prompt and a control image are required
        if prompt is None or image is None:
            return {"error": "Please provide a prompt and a base64 encoded image."}
        
        # swap in a different ControlNet if the request asks for one
        if controlnet_type is not None and controlnet_type != self.control_type:
            if controlnet_type not in CONTROLNET_MAPPING:
                return {"error": f"Unknown controlnet_type '{controlnet_type}'. Supported types: {list(CONTROLNET_MAPPING)}"}
            print(f"changing controlnet from {self.control_type} to {controlnet_type} using {CONTROLNET_MAPPING[controlnet_type]['model_id']} model")
            self.control_type = controlnet_type
            self.controlnet = ControlNetModel.from_pretrained(
                CONTROLNET_MAPPING[self.control_type]["model_id"], torch_dtype=dtype
            ).to(device)
            self.pipe.controlnet = self.controlnet
        
        
        # hyperparameters
        num_inference_steps = data.pop("num_inference_steps", 30)
        guidance_scale = data.pop("guidance_scale", 7.0)
        negative_prompt = data.pop("negative_prompt", None)
        height = data.pop("height", None)
        width = data.pop("width", None)
        controlnet_conditioning_scale = data.pop("controlnet_conditioning_scale", 1.2)
        
        # decode the incoming image; the client is expected to send an already-prepared
        # control image, so the server-side hinter call is left disabled
        image = self.decode_base64_image(image)
        # control_image = CONTROLNET_MAPPING[self.control_type]["hinter"](image)

        # run inference pipeline
        out = self.pipe(
            prompt=prompt, 
            negative_prompt=negative_prompt,
            #image=control_image,
            image=image,
            num_inference_steps=num_inference_steps, 
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            height=height,
            width=width,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            generator=self.generator
        )

        
        # return the first generated PIL image
        return out.images[0]
    
    # helper to decode a base64-encoded image string into a PIL image
    def decode_base64_image(self, image_string):
        base64_image = base64.b64decode(image_string)
        buffer = BytesIO(base64_image)
        image = Image.open(buffer)
        return image
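

# --- Local smoke test (a sketch, not part of the endpoint contract) ---
# Assumptions: this file is run directly on a GPU machine and a local image
# `test.jpg` exists (hypothetical name). The control image is prepared
# client-side with the hinter, mirroring the disabled server-side call above.
if __name__ == "__main__":
    handler = EndpointHandler()

    # build a canny control image client-side and base64-encode it
    src = Image.open("test.jpg").convert("RGB")
    hint = CONTROLNET_MAPPING["canny_edge"]["hinter"](src)
    buffer = BytesIO()
    hint.save(buffer, format="PNG")
    image_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

    result = handler({
        "inputs": "a photorealistic portrait, studio lighting",
        "image": image_b64,
        "controlnet_type": "canny_edge",
        "num_inference_steps": 30,
    })
    if isinstance(result, dict):
        print(result)  # error payload
    else:
        result.save("output.png")  # generated PIL image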