owaiskaifi commited on
Commit
6a5d151
1 Parent(s): 66daff5

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +144 -0
handler.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from PIL import Image
4
+ import qrcode
5
+ from pathlib import Path
6
+ from multiprocessing import cpu_count
7
+ import requests
8
+ import io
9
+ import os
10
+ from PIL import Image
11
+ from diffusers import (
12
+ StableDiffusionPipeline,
13
+ StableDiffusionControlNetImg2ImgPipeline,
14
+ ControlNetModel,
15
+ DDIMScheduler,
16
+ DPMSolverMultistepScheduler,
17
+ DEISMultistepScheduler,
18
+ HeunDiscreteScheduler,
19
+ EulerDiscreteScheduler,
20
+ )
21
+
22
+
23
+ SAMPLER_MAP = {
24
+ "DPM++ Karras SDE": lambda config: DPMSolverMultistepScheduler.from_config(config, use_karras=True, algorithm_type="sde-dpmsolver++"),
25
+ "DPM++ Karras": lambda config: DPMSolverMultistepScheduler.from_config(config, use_karras=True),
26
+ "Heun": lambda config: HeunDiscreteScheduler.from_config(config),
27
+ "Euler": lambda config: EulerDiscreteScheduler.from_config(config),
28
+ "DDIM": lambda config: DDIMScheduler.from_config(config),
29
+ "DEIS": lambda config: DEISMultistepScheduler.from_config(config),
30
+ }
31
+
32
+
33
+ def resize_for_condition_image(input_image: Image.Image, resolution: int):
34
+ input_image = input_image.convert("RGB")
35
+ W, H = input_image.size
36
+ k = float(resolution) / min(H, W)
37
+ H *= k
38
+ W *= k
39
+ H = int(round(H / 64.0)) * 64
40
+ W = int(round(W / 64.0)) * 64
41
+ img = input_image.resize((W, H), resample=Image.LANCZOS)
42
+ return img
43
+
44
+ class EndpointHandler():
45
+ def __init__(self, path=""):
46
+ qrcode_generator = qrcode.QRCode(
47
+ version=1,
48
+ error_correction=qrcode.ERROR_CORRECT_H,
49
+ box_size=10,
50
+ border=4,
51
+ )
52
+
53
+ controlnet = ControlNetModel.from_pretrained(
54
+ "DionTimmer/controlnet_qrcode-control_v1p_sd15", torch_dtype=torch.float16
55
+ )
56
+
57
+ pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
58
+ "runwayml/stable-diffusion-v1-5",
59
+ controlnet=controlnet,
60
+ safety_checker=None,
61
+ torch_dtype=torch.float16,
62
+ ).to("cuda")
63
+ pipe.enable_xformers_memory_efficient_attention()
64
+
65
+
66
+ def __call__inference(self,
67
+ qr_code_content: str,
68
+ prompt: str,
69
+ negative_prompt: str,
70
+ guidance_scale: float = 10.0,
71
+ controlnet_conditioning_scale: float = 2.0,
72
+ strength: float = 0.8,
73
+ seed: int = -1,
74
+ init_image: Image.Image | None = None,
75
+ qrcode_image: Image.Image | None = None,
76
+ use_qr_code_as_init_image = True,
77
+ sampler = "DPM++ Karras SDE",
78
+ ):
79
+ if prompt is None or prompt == "":
80
+ raise gr.Error("Prompt is required")
81
+
82
+ if qrcode_image is None and qr_code_content == "":
83
+ raise gr.Error("QR Code Image or QR Code Content is required")
84
+
85
+ pipe.scheduler = SAMPLER_MAP[sampler](pipe.scheduler.config)
86
+
87
+ generator = torch.manual_seed(seed) if seed != -1 else torch.Generator()
88
+
89
+ if qr_code_content != "" or qrcode_image.size == (1, 1):
90
+ qr = qrcode.QRCode(
91
+ version=1,
92
+ error_correction=qrcode.constants.ERROR_CORRECT_H,
93
+ box_size=10,
94
+ border=4,
95
+ )
96
+ qr.add_data(qr_code_content)
97
+ qr.make(fit=True)
98
+ qrcode_image = qr.make_image(fill_color="black", back_color="white")
99
+
100
+ if init_image is None:
101
+ if use_qr_code_as_init_image:
102
+ init_image = qrcode_image.convert("RGB")
103
+
104
+ resolution = controlnet.config.resolution
105
+ qrcode_image = resize_for_condition_image(qrcode_image, resolution)
106
+ if init_image is not None:
107
+ init_image = init_image.convert("RGB")
108
+ init_image = resize_for_condition_image(init_image, resolution)
109
+ init_image = torch.nn.functional.interpolate(
110
+ torch.nn.functional.to_tensor(init_image).unsqueeze(0),
111
+ size=(resolution, resolution),
112
+ mode="bilinear",
113
+ align_corners=False,
114
+ )[0].unsqueeze(0)
115
+ else:
116
+ init_image = torch.zeros(
117
+ (1, 3, resolution, resolution), device=pipe.device
118
+ ).to(dtype=torch.float32)
119
+
120
+ with torch.no_grad():
121
+ result_image = pipe(
122
+ qr_code_condition=qrcode_image,
123
+ prompt=prompt,
124
+ negative_prompt=negative_prompt,
125
+ init_image=init_image,
126
+ strength=strength,
127
+ guidance_scale=guidance_scale,
128
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
129
+ disable_progress_bar=True,
130
+ seed=generator,
131
+ ).cpu()
132
+
133
+ result_image = (
134
+ result_image.clamp(-1, 1).squeeze().permute(1, 2, 0).numpy() * 255
135
+ )
136
+ result_image = Image.fromarray(result_image.astype("uint8"))
137
+
138
+ return result_image
139
+
140
+
141
+
142
+
143
+
144
+