File size: 4,871 Bytes
6a5d151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec30585
6a5d151
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import torch
from PIL import Image
import qrcode
from pathlib import Path
from multiprocessing import cpu_count
import requests
import io
import os
from PIL import Image
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionControlNetImg2ImgPipeline,
    ControlNetModel,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    DEISMultistepScheduler,
    HeunDiscreteScheduler,
    EulerDiscreteScheduler,
)


# Sampler display name -> factory that builds that scheduler from an existing
# scheduler config (the call site passes ``pipe.scheduler.config``).
SAMPLER_MAP = {
    # diffusers names the Karras option ``use_karras_sigmas``. The former
    # ``use_karras`` kwarg is not a scheduler config field, so ``from_config``
    # ignored it and the "Karras" samplers ran on the default sigma schedule.
    "DPM++ Karras SDE": lambda config: DPMSolverMultistepScheduler.from_config(
        config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
    ),
    "DPM++ Karras": lambda config: DPMSolverMultistepScheduler.from_config(
        config, use_karras_sigmas=True
    ),
    "Heun": lambda config: HeunDiscreteScheduler.from_config(config),
    "Euler": lambda config: EulerDiscreteScheduler.from_config(config),
    "DDIM": lambda config: DDIMScheduler.from_config(config),
    "DEIS": lambda config: DEISMultistepScheduler.from_config(config),
}


def resize_for_condition_image(input_image: Image.Image, resolution: int):
    """Return an RGB copy of ``input_image`` scaled for ControlNet conditioning.

    The image is scaled uniformly so that its shorter side becomes
    ``resolution`` pixels. Each resulting side is then snapped to the nearest
    multiple of 64 (Python ``round``, i.e. ties go to the even multiple), and
    the image is resampled with the Lanczos filter.

    Args:
        input_image: Source image; converted to RGB before resizing.
        resolution: Desired length, in pixels, of the shorter side.

    Returns:
        The resized RGB image.
    """
    rgb_image = input_image.convert("RGB")
    width, height = rgb_image.size
    scale = float(resolution) / min(height, width)

    def _snap_to_64(length):
        # Scale first, then round to the closest multiple of 64.
        return int(round(length * scale / 64.0)) * 64

    target_size = (_snap_to_64(width), _snap_to_64(height))
    return rgb_image.resize(target_size, resample=Image.LANCZOS)

class EndpointHandler:
    """Generate stylised QR codes with Stable Diffusion 1.5 and a QR ControlNet.

    The ControlNet and the ControlNet img2img pipeline are loaded once in
    ``__init__`` and reused by every call to :meth:`inference` (or to the
    instance itself via :meth:`__call__`).
    """

    def __init__(self, path=""):
        """Load the QR-code ControlNet and the SD 1.5 img2img pipeline on CUDA.

        Args:
            path: Accepted for the handler interface but not used; the weights
                always come from the hard-coded Hub repositories below.
        """
        # Stored on the instance so inference() can reach them; as plain
        # locals they were discarded when __init__ returned.
        self.controlnet = ControlNetModel.from_pretrained(
            "DionTimmer/controlnet_qrcode-control_v1p_sd15", torch_dtype=torch.float16
        )
        self.pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=self.controlnet,
            safety_checker=None,
            torch_dtype=torch.float16,
        ).to("cuda")
        self.pipe.enable_xformers_memory_efficient_attention()
        # The pipeline call has no ``disable_progress_bar`` argument; this is
        # the pipeline-level switch for the per-step progress bar.
        self.pipe.set_progress_bar_config(disable=True)

    def __call__(self, data):
        """Unpack a request payload and forward it to :meth:`inference`.

        NOTE(review): the payload shape is an assumption. This accepts either
        a flat mapping of ``inference`` keyword arguments, or
        ``{"inputs": <mapping or prompt string>, "parameters": <mapping>}``,
        where "parameters" entries override "inputs" entries. Confirm against
        the real callers.

        Returns:
            Whatever :meth:`inference` returns (the first pipeline image).
        """
        payload = dict(data)
        parameters = payload.pop("parameters", None) or {}
        inputs = payload.pop("inputs", payload)
        kwargs = {"prompt": inputs} if isinstance(inputs, str) else dict(inputs)
        kwargs.update(parameters)
        return self.inference(**kwargs)

    def inference(
        self,
        qr_code_content: str,
        prompt: str,
        negative_prompt: str,
        guidance_scale: float = 10.0,
        controlnet_conditioning_scale: float = 2.0,
        strength: float = 0.8,
        seed: int = -1,
        init_image: Image.Image | None = None,
        qrcode_image: Image.Image | None = None,
        use_qr_code_as_init_image: bool = True,
        sampler: str = "DPM++ Karras SDE",
        resolution: int = 768,
    ):
        """Render one image conditioned on a QR code.

        Args:
            qr_code_content: Text to encode. When non-empty, a fresh QR code is
                generated from it (taking precedence over ``qrcode_image``).
            prompt: Positive text prompt; must be non-empty.
            negative_prompt: Negative text prompt.
            guidance_scale: Classifier-free guidance scale.
            controlnet_conditioning_scale: Weight of the ControlNet residuals.
            strength: img2img strength passed to the pipeline.
            seed: Generator seed; ``-1`` draws a fresh non-deterministic seed.
            init_image: Optional img2img starting image. It is resized to the
                QR condition image's size (aspect ratio is not preserved).
            qrcode_image: Pre-rendered QR code. It is used only when
                ``qr_code_content`` is empty and the image is not 1x1.
            use_qr_code_as_init_image: When no ``init_image`` is given, start
                from the QR code itself; otherwise start from a black image.
            sampler: Key of ``SAMPLER_MAP`` selecting the scheduler.
            resolution: Target shorter side for the condition image. The
                ControlNet config has no resolution field; 768 is the default.

        Returns:
            ``output.images[0]`` from the pipeline call.

        Raises:
            ValueError: Missing prompt, no QR source, or an unknown sampler.
        """
        if not prompt:
            raise ValueError("Prompt is required")

        # Treat None like "" so the checks below never call add_data(None).
        qr_code_content = qr_code_content or ""
        if qrcode_image is None and qr_code_content == "":
            raise ValueError("QR Code Image or QR Code Content is required")

        try:
            make_scheduler = SAMPLER_MAP[sampler]
        except KeyError:
            raise ValueError(
                f"Unknown sampler {sampler!r}; expected one of {list(SAMPLER_MAP)}"
            ) from None
        # NOTE: mutates the shared pipeline, so concurrent calls would race.
        self.pipe.scheduler = make_scheduler(self.pipe.scheduler.config)

        # Use a private generator instead of reseeding torch's global RNG.
        # A bare torch.Generator() starts from a fixed default seed, so -1
        # would otherwise give the same image every time.
        generator = torch.Generator()
        if seed == -1:
            generator.seed()
        else:
            generator.manual_seed(seed)

        # A 1x1 image is treated as "no image supplied".
        if qr_code_content != "" or qrcode_image.size == (1, 1):
            qr = qrcode.QRCode(
                version=1,
                error_correction=qrcode.constants.ERROR_CORRECT_H,
                box_size=10,
                border=4,
            )
            qr.add_data(qr_code_content)
            qr.make(fit=True)
            qrcode_image = qr.make_image(fill_color="black", back_color="white")

        qrcode_image = resize_for_condition_image(qrcode_image, resolution)
        width, height = qrcode_image.size

        if init_image is None and use_qr_code_as_init_image:
            init_image = qrcode_image
        if init_image is None:
            # Blank (black) starting image, matching the old all-zeros tensor.
            init_image = Image.new("RGB", (width, height))
        else:
            # The init and control images must share a size for the latents to line up.
            init_image = init_image.convert("RGB").resize(
                (width, height), resample=Image.LANCZOS
            )

        with torch.no_grad():
            output = self.pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=init_image,
                control_image=qrcode_image,
                width=width,
                height=height,
                strength=float(strength),
                guidance_scale=float(guidance_scale),
                controlnet_conditioning_scale=float(controlnet_conditioning_scale),
                generator=generator,
            )
        return output.images[0]

    # Backward-compatible alias. The old method name starts with "__" without
    # ending in "__", so Python mangles it to
    # ``_EndpointHandler__call__inference``; keep that name working.
    __call__inference = inference