File size: 2,827 Bytes
c4b70cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from typing import Any
import torch, base64
from PIL import Image
from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, DDIMScheduler
from diffusers.utils import load_image
from io import BytesIO

class EndpointHandler:
    """Inference handler that renders a QR code into a generated picture.

    A QR-code ControlNet ("DionTimmer/controlnet_qrcode-control_v11p_sd21")
    conditions a Stable Diffusion 2.1 img2img pipeline. The caller's QR code
    is the control image; a fixed remote picture is the img2img start image.
    """

    # Target size (pixels) for the shorter image side and the generated output.
    RESOLUTION = 768
    # Fixed img2img start image; downloaded and resized once, then reused.
    INIT_IMAGE_URL = 'https://images.squarespace-cdn.com/content/v1/59413d96e6f2e1c6837c7ecd/1536503659130-R84NUPOY4QPQTEGCTSAI/15fe1e62172035.5a87280d713e4.png'
    DEFAULT_PROMPT = "a bilboard in NYC with a qrcode"
    DEFAULT_NEGATIVE_PROMPT = "ugly, disfigured, low quality, blurry, nsfw, worst quality, illustration, drawing"
    # Fixed seed so identical requests produce identical images.
    SEED = 123121231

    def __init__(self, path=""):
        """Load the ControlNet and the img2img pipeline (fp16).

        Args:
            path: Accepted for the handler interface; unused here because the
                weights are fetched from the Hub by repository id.
        """
        self.controlnet = ControlNetModel.from_pretrained(
            "DionTimmer/controlnet_qrcode-control_v11p_sd21",
            torch_dtype=torch.float16,
        )
        self.pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-1",
            controlnet=self.controlnet,
            safety_checker=None,
            torch_dtype=torch.float16,
        )

        # Needs the optional xformers package to be installed.
        self.pipe.enable_xformers_memory_efficient_attention()
        self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
        self.pipe.enable_model_cpu_offload()
        # Resized init image, populated on the first request.
        self._init_image = None

    def __call__(self, data: dict) -> dict:
        """Generate an image that embeds the given QR code.

        Args:
            data: Request payload. ``data["inputs"]`` is the QR code image
                (anything ``diffusers.utils.load_image`` accepts, e.g. a URL
                or path). Optional ``data["parameters"]`` is a dict that may
                hold ``prompt`` and ``negative_prompt`` strings; when
                ``"parameters"`` is absent those keys are read from ``data``
                itself. Missing or empty prompts fall back to the class
                defaults. ``data`` is mutated: ``inputs``/``parameters``
                are popped from it.

        Returns:
            dict: ``{"image": <base64-encoded JPEG as str>}``.
        """
        inputs, params = self._split_payload(data)
        prompt = params.get("prompt")
        negative_prompt = params.get("negative_prompt")

        original_qr_code_image = load_image(inputs)
        init_image = self._get_init_image()
        condition_image = self._resize_image(original_qr_code_image, self.RESOLUTION)

        # A private generator yields the same seeded CPU noise as
        # torch.manual_seed(SEED) without reseeding torch's global RNG.
        generator = torch.Generator().manual_seed(self.SEED)
        result = self.pipe(
            prompt=prompt or self.DEFAULT_PROMPT,
            negative_prompt=negative_prompt or self.DEFAULT_NEGATIVE_PROMPT,
            image=init_image,
            control_image=condition_image,
            width=self.RESOLUTION,
            height=self.RESOLUTION,
            guidance_scale=20,
            controlnet_conditioning_scale=2.5,
            generator=generator,
            strength=0.9,
            num_inference_steps=150,
        )

        return {"image": self._encode_jpeg_base64(result.images[0])}

    @staticmethod
    def _split_payload(data):
        """Pop and return ``(inputs, params)`` from the request payload.

        If a key is missing, the (already-popped) payload dict itself is used
        in its place, so flat payloads like ``{"inputs": ..., "prompt": ...}``
        keep working.
        """
        inputs = data.pop("inputs", data)
        params = data.pop("parameters", data)
        # An explicit ``"parameters": null`` would otherwise crash on ``.get``.
        if params is None:
            params = {}
        return inputs, params

    @staticmethod
    def _target_size(width, height, resolution):
        """Return ``(W, H)`` scaled so the shorter side is ~``resolution``.

        Both sides are snapped to the nearest multiple of 64 (Python's
        ``round``, i.e. ties to even).
        """
        scale = float(resolution) / min(height, width)
        new_width = int(round(width * scale / 64.0)) * 64
        new_height = int(round(height * scale / 64.0)) * 64
        return new_width, new_height

    @classmethod
    def _resize_image(cls, input_image: "Image.Image", resolution: int) -> "Image.Image":
        """Return an RGB copy of ``input_image`` resized via ``_target_size`` (Lanczos)."""
        rgb_image = input_image.convert("RGB")
        size = cls._target_size(*rgb_image.size, resolution)
        return rgb_image.resize(size, resample=Image.LANCZOS)

    def _get_init_image(self):
        """Return the resized img2img start image, fetching it on first use.

        If the download fails, nothing is cached and the next call retries.
        """
        if getattr(self, "_init_image", None) is None:
            self._init_image = self._resize_image(
                load_image(self.INIT_IMAGE_URL), self.RESOLUTION
            )
        return self._init_image

    @staticmethod
    def _encode_jpeg_base64(image):
        """Serialize ``image`` as JPEG and return it base64-encoded as ``str``."""
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode()