from typing import Any, Dict
import torch
import base64
from PIL import Image
from io import BytesIO
from diffusers import (
    T2IAdapter,
    StableDiffusionXLAdapterPipeline,
    StableDiffusionXLImg2ImgPipeline,
    AutoencoderKL,
    DPMSolverMultistepScheduler,
)
from controlnet_aux.pidi import PidiNetDetector

# set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if device.type != 'cuda':
    raise ValueError("This handler requires a CUDA-capable GPU to run.")

class EndpointHandler:
    # Preload every component needed at inference time.
    def __init__(self, path=""):

        # load the T2I adapter
        adapter = T2IAdapter.from_pretrained(
            "Adapter/t2iadapter",
            subfolder="sketch_sdxl_1.0",
            torch_dtype=torch.float16,
            adapter_type="full_adapter_xl",
            use_safetensors=True,
        )

        # load variational autoencoder (VAE)
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
            use_safetensors=True,
        )

        # load the scheduler
        scheduler = DPMSolverMultistepScheduler.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            subfolder="scheduler",
            use_lu_lambdas=True,
            euler_at_final=True,
        )

        # instantiate HF pipeline to combine all the components
        self.pipeline = StableDiffusionXLAdapterPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            adapter=adapter,
            vae=vae,
            scheduler=scheduler,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        )

        # instantiate HF refiner to improve output image
        self.refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-refiner-1.0",
            text_encoder_2=self.pipeline.text_encoder_2,
            vae=vae,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        )

        # Model CPU offload moves each submodule to the GPU only while it runs,
        # cutting peak VRAM. It manages device placement itself, which is why the
        # two pipelines above are not moved to CUDA with .to("cuda").
        self.pipeline.enable_model_cpu_offload()
        self.refiner.enable_model_cpu_offload()

        self.pidinet = PidiNetDetector.from_pretrained("lllyasviel/Annotators").to("cuda")

    def __call__(self, data: Dict[str, Any]) -> Image.Image:
        """
        data args:
            inputs (str): text prompt describing the desired image
            image (str): base64-encoded conditioning image
            adapter_conditioning_scale (float, optional): adapter strength, defaults to 1.0
            adapter_conditioning_factor (float, optional): fraction of timesteps the adapter is applied for, defaults to 1.0
        Return:
            PIL.Image.Image: the generated image, which will be serialized and returned
        """

        # get inputs
        inputs = data.pop("inputs", "")
        encoded_image = data.pop("image", None)
        adapter_conditioning_scale = data.pop("adapter_conditioning_scale", 1.0)
        adapter_conditioning_factor = data.pop("adapter_conditioning_factor", 1.0)

        if encoded_image is None:
            raise ValueError("Request must include a base64-encoded 'image' field.")

        # decode the image and convert it to a black-and-white sketch
        decoded_image = self.decode_base64_image(encoded_image).convert('RGB')
        sketch_image = self.pidinet(
            decoded_image,
            detect_resolution=1024,
            image_resolution=1024,
            apply_filter=True
        ).convert('L')

        # sketch_image.save("./output1.png")  # uncomment to inspect the intermediate sketch

        # split denoising between the two pipelines: the base model runs the first
        # 70% of the steps and hands its latents to the refiner for the remainder
        num_inference_steps = 25
        high_noise_frac = 0.7
        base_image = self.pipeline(
            prompt=inputs,
            negative_prompt="extra digit, fewer digits, cropped, worst quality, low quality",
            image=sketch_image,
            num_inference_steps=num_inference_steps,
            denoising_end=high_noise_frac,
            guidance_scale=7.5,
            adapter_conditioning_scale=adapter_conditioning_scale,
            adapter_conditioning_factor=adapter_conditioning_factor,
            output_type="latent",
        ).images

        # the img2img refiner does not accept the adapter conditioning arguments,
        # so only the shared generation settings are passed
        output_image = self.refiner(
            prompt=inputs,
            negative_prompt="extra digit, fewer digits, cropped, worst quality, low quality",
            image=base_image,
            num_inference_steps=num_inference_steps,
            denoising_start=high_noise_frac,
            guidance_scale=7.5,
        ).images[0]

        # output_image.save("./output2.png")  # uncomment to inspect the final image
        return output_image

    # helper to decode a base64-encoded input image into a PIL image
    def decode_base64_image(self, image_string: str) -> Image.Image:
        base64_image = base64.b64decode(image_string)
        buffer = BytesIO(base64_image)
        image = Image.open(buffer)
        return image
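

# Minimal local smoke test: a sketch of how a client payload maps onto __call__
# above, not part of the deployed handler. "./input.png", the prompt text, and
# "./result.png" are hypothetical placeholders; the model weights download on
# first run and require a GPU with enough VRAM.
if __name__ == "__main__":
    handler = EndpointHandler()

    # encode a local image the same way a client would before sending the request
    with open("./input.png", "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")

    result = handler({
        "inputs": "a watercolor lighthouse on a cliff at dusk",
        "image": encoded,
        "adapter_conditioning_scale": 0.9,
        "adapter_conditioning_factor": 1.0,
    })
    result.save("./result.png")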