brainstone committed on
Commit 0c8abab
1 Parent(s): 5d45a2a

Upload QRDiffuser.py

Files changed (1): QRDiffuser.py +227 -0
QRDiffuser.py ADDED
@@ -0,0 +1,227 @@
import torch
import gradio as gr
from PIL import Image
import qrcode
import os

from diffusers import (
    StableDiffusionControlNetPipeline,
    ControlNetModel,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    UniPCMultistepScheduler,
    DEISMultistepScheduler,
    HeunDiscreteScheduler,
    EulerDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
)

controlnet = ControlNetModel.from_pretrained(
    "monster-labs/control_v1p_sd15_qrcode_monster",
    torch_dtype=torch.float16,
)

pipe = StableDiffusionControlNetPipeline.from_pretrained(
    # "runwayml/stable-diffusion-v1-5",
    "SG161222/Realistic_Vision_V3.0_VAE",
    controlnet=controlnet,
    safety_checker=None,
    torch_dtype=torch.float16,
).to("cuda")
# memory-saving options (xformers and VAE tiling are left disabled)
# pipe.enable_xformers_memory_efficient_attention()
pipe.enable_attention_slicing(1)
pipe.enable_model_cpu_offload()
# pipe.enable_vae_tiling()
pipe.enable_vae_slicing()
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

SAMPLER_MAP = {
    "DPM++ Karras SDE": lambda config: DPMSolverMultistepScheduler.from_config(config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"),
    "DPM++ Karras": lambda config: DPMSolverMultistepScheduler.from_config(config, use_karras_sigmas=True),
    "Heun": lambda config: HeunDiscreteScheduler.from_config(config),
    "Euler a": lambda config: EulerAncestralDiscreteScheduler.from_config(config),
    "Euler": lambda config: EulerDiscreteScheduler.from_config(config),
    "DDIM": lambda config: DDIMScheduler.from_config(config),
    "DEIS": lambda config: DEISMultistepScheduler.from_config(config),
}

boxsize = 16


def create_code(content: str, errorCorrection: str):
    match errorCorrection:
        case "L 7%":
            errCorr = qrcode.constants.ERROR_CORRECT_L
        case "M 15%":
            errCorr = qrcode.constants.ERROR_CORRECT_M
        case "Q 25%":
            errCorr = qrcode.constants.ERROR_CORRECT_Q
        case "H 30%":
            errCorr = qrcode.constants.ERROR_CORRECT_H
        case _:
            # unknown label: fall back to the strongest error correction
            errCorr = qrcode.constants.ERROR_CORRECT_H

    qr = qrcode.QRCode(
        version=1,
        error_correction=errCorr,
        box_size=boxsize,
        border=0,
    )
    qr.add_data(content)
    qr.make(fit=True)
    img = qr.make_image(fill_color="black", back_color="white")

    # find smallest image size multiple of 256 that can fit qr
    offset_min = 8 * boxsize
    w, h = img.size
    w = (w + 255 + offset_min) // 256 * 256
    h = (h + 255 + offset_min) // 256 * 256
    if w > 1024:
        raise gr.Error("QR code is too large, please use a shorter content")
    bg = Image.new('L', (w, h), 128)

    # align on 16px grid
    coords = ((w - img.size[0]) // 2 // boxsize * boxsize,
              (h - img.size[1]) // 2 // boxsize * boxsize)
    bg.paste(img, coords)
    return bg

def inference(
    qr_code_content: str,
    errorCorrection: str,
    prompt: str,
    negative_prompt: str,
    inferenceSteps: float,
    guidance_scale: float = 10.0,
    controlnet_conditioning_scale: float = 2.0,
    seed: int = -1,
    sampler="Euler a",
):
    if prompt is None or prompt == "":
        raise gr.Error("Prompt is required")

    if qr_code_content is None or qr_code_content == "":
        raise gr.Error("QR Code Content is required")

    pipe.scheduler = SAMPLER_MAP[sampler](pipe.scheduler.config)

    generator = torch.manual_seed(seed) if seed != -1 else torch.Generator()

    print("Generating QR Code from content")
    qrcode_image = create_code(qr_code_content, errorCorrection)

    # hack due to gradio examples: dumps the control image to a hardcoded (Windows) path
    init_image = qrcode_image
    init_image.save("c:\\temp\\qr.jpg")

    out = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=qrcode_image,
        width=qrcode_image.width,
        height=qrcode_image.height,
        guidance_scale=float(guidance_scale),
        controlnet_conditioning_scale=float(controlnet_conditioning_scale),
        generator=generator,
        num_inference_steps=int(inferenceSteps),  # the slider delivers a float; the pipeline expects an int
    )
    return out.images[0]

css = """
#result_image {
    display: flex;
    place-content: center;
    align-items: center;
}
#result_image > img {
    height: auto;
    max-width: 100%;
    width: revert;
}
"""

with gr.Blocks(css=css) as blocks:

    with gr.Row():
        with gr.Column():
            qr_code_content = gr.Textbox(
                label="QR Code Content or URL",
                info="The text you want to encode into the QR code",
                value="",
            )
            errorCorrection = gr.Dropdown(
                label="QR Code Error Correction Level",
                choices=["L 7%", "M 15%", "Q 25%", "H 30%"],
                value="H 30%",
            )

            prompt = gr.Textbox(
                label="Prompt",
                info="Prompt that the generation is guided towards",
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                value="ugly, disfigured, low quality, blurry, nsfw",
                info="Prompt that the generation is guided away from",
            )
            inferenceSteps = gr.Slider(
                minimum=10.0,
                maximum=60.0,
                step=1,
                value=20,
                label="Inference Steps",
                info="More steps give a better image but a longer runtime",
            )

            with gr.Accordion(
                label="Params: whether the generated QR code stays scannable is largely influenced by the parameters below",
                open=True,
            ):
                controlnet_conditioning_scale = gr.Slider(
                    minimum=0.5,
                    maximum=2.5,
                    step=0.01,
                    value=1.5,
                    label="Controlnet Conditioning Scale",
                    info="""Controls the readability/creativity of the QR code.
                    High values: The generated QR code will be more readable.
                    Low values: The generated QR code will be more creative.
                    """,
                )
                guidance_scale = gr.Slider(
                    minimum=0.0,
                    maximum=25.0,
                    step=0.25,
                    value=7,
                    label="Guidance Scale",
                    info="Controls how strongly the text prompt guides the image generation",
                )
                sampler = gr.Dropdown(
                    choices=list(SAMPLER_MAP.keys()), value="Euler a", label="Sampler"
                )
                seed = gr.Number(
                    minimum=-1,
                    maximum=9999999999,
                    value=-1,
                    label="Seed",
                    info="Seed for the random number generator. Set to -1 for a random seed",
                )
            with gr.Row():
                run_btn = gr.Button("Run")
        with gr.Column():
            result_image = gr.Image(label="Result Image", elem_id="result_image")
    run_btn.click(
        inference,
        inputs=[
            qr_code_content,
            errorCorrection,
            prompt,
            negative_prompt,
            inferenceSteps,
            guidance_scale,
            controlnet_conditioning_scale,
            seed,
            sampler,
        ],
        outputs=[result_image],
    )

blocks.queue(concurrency_count=1, max_size=20, api_open=False)
blocks.launch(share=bool(os.environ.get("SHARE", True)), show_api=False)
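
For reference, a minimal standalone sketch (not part of the uploaded file) of the canvas-size arithmetic in create_code: it shows how the raw QR image is padded up to the next multiple of 256 with at least an 8-module margin, and why long content trips the 1024 px cap. The module counts are the standard QR sizes for versions 1, 6 and 10; boxsize, the margin, and the cap are taken from the values in the code above.

# Illustrative only: reproduces the padding arithmetic from create_code()
boxsize = 16
offset_min = 8 * boxsize  # 128 px minimum margin around the code

for version, modules in [(1, 21), (6, 41), (10, 57)]:
    side = modules * boxsize                         # raw QR image size in px
    canvas = (side + 255 + offset_min) // 256 * 256  # next multiple of 256 that fits code + margin
    print(f"version {version}: {side}px -> {canvas}px canvas, fits={canvas <= 1024}")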