from typing import Dict, Any
import base64
from PIL import Image
from io import BytesIO
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
import torch
import controlnet_hinter

# set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type != 'cuda':
    raise ValueError("Need to run on GPU")
# set mixed precision dtype: bfloat16 on compute capability >= 8.0 (Ampere and newer), float16 otherwise
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16

# controlnet mapping for depth controlnet
CONTROLNET_MAPPING = {
    "depth": {
        "model_id": "lllyasviel/sd-controlnet-depth",
        "hinter": controlnet_hinter.hint_depth
    }
}
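
# Note: other control types (e.g. "canny" or "hed") could be added to this
# mapping in the same shape, pairing a pretrained ControlNet checkpoint id
# with the controlnet_hinter function that produces its conditioning image.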

class EndpointHandler:
    def __init__(self, path=""):
        # define default controlnet id and load controlnet
        self.control_type = "depth"
        self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"], torch_dtype=dtype).to(device)

        # Load StableDiffusionControlNetPipeline
        self.stable_diffusion_id = "runwayml/stable-diffusion-v1-5"
        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(self.stable_diffusion_id,
                                                                      controlnet=self.controlnet,
                                                                      torch_dtype=dtype,
                                                                      safety_checker=None).to(device)
        # Define Generator with seed
        self.generator = torch.Generator(device="cpu").manual_seed(3)

    def __call__(self, data: Any) -> Dict[str, str]:
        """Run depth-ControlNet-guided generation from an A1111-style payload.

        Example payload:
        {
            "prompt": "a beautiful landscape",
            "negative_prompt": "blur",
            "width": 1024,
            "height": 1024,
            "steps": 30,
            "cfg_scale": 7,
            "alwayson_scripts": {
                "controlnet": {
                    "args": [
                        {
                            "enabled": True,
                            "input_image": "image in base64",
                            "model": "control_sd15_depth [fef5e48e]",
                            "control_mode": "Balanced"
                        }
                    ]
                }
            }
        }
        """
        # Extract generation parameters from the payload
        prompt = data.get("prompt", None)
        negative_prompt = data.get("negative_prompt", None)
        width = data.get("width", None)
        height = data.get("height", None)
        num_inference_steps = data.get("steps", 30)
        guidance_scale = data.get("cfg_scale", 7)

        # Extract the controlnet configuration from the payload. The "model"
        # and "control_mode" fields are accepted for payload compatibility,
        # but this handler always uses the depth controlnet loaded in __init__.
        controlnet_config = data.get("alwayson_scripts", {}).get("controlnet", {}).get("args", [{}])[0]

        # A ControlNet pipeline always needs a conditioning image
        if not controlnet_config.get("enabled", False):
            raise ValueError("The controlnet config must be enabled and provide an input_image")

        # Decode the input image and derive the depth hint used for conditioning
        input_image = self.decode_base64_image(controlnet_config.get("input_image", ""))
        control_image = CONTROLNET_MAPPING[self.control_type]["hinter"](input_image)

        # Run the ControlNet-conditioned stable diffusion process
        out = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=control_image,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            height=height,
            width=width,
            controlnet_conditioning_scale=1.0,
            generator=self.generator,
        )

        # Return the generated image as base64
        return {"image": self.encode_base64_image(out.images[0])}

    def encode_base64_image(self, image):
        # Encode the PIL Image to base64
        buffer = BytesIO()
        image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def decode_base64_image(self, image_string):
        # Decode a base64 string into an RGB PIL Image
        base64_image = base64.b64decode(image_string)
        buffer = BytesIO(base64_image)
        image = Image.open(buffer)
        return image.convert("RGB")
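
# Minimal local smoke test, a sketch under these assumptions: it runs on a GPU
# machine outside Inference Endpoints, and "example.jpg" is a hypothetical
# input image you supply yourself.
if __name__ == "__main__":
    handler = EndpointHandler()
    with open("example.jpg", "rb") as f:  # hypothetical input file
        image_b64 = base64.b64encode(f.read()).decode("utf-8")
    payload = {
        "prompt": "a beautiful landscape",
        "negative_prompt": "blur",
        "steps": 30,
        "cfg_scale": 7,
        "alwayson_scripts": {"controlnet": {"args": [{
            "enabled": True,
            "input_image": image_b64,
        }]}},
    }
    result = handler(payload)
    with open("output.png", "wb") as f:
        f.write(base64.b64decode(result["image"]))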