# QR-code art generator — Stable Diffusion + ControlNet (qrcode-monster), Gradio app for HF Spaces.
# Standard library
import os

# Third-party
import gradio as gr
import qrcode
import torch
from PIL import Image
from diffusers import (
    StableDiffusionControlNetPipeline,
    ControlNetModel,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    UniPCMultistepScheduler,
    DEISMultistepScheduler,
    HeunDiscreteScheduler,
    EulerDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
)
# --- Model / pipeline setup -------------------------------------------------
# ControlNet trained to blend scannable QR codes into generated imagery.
controlnet = ControlNetModel.from_pretrained(
    "monster-labs/control_v1p_sd15_qrcode_monster",
    torch_dtype=torch.float16,
)

pipe = StableDiffusionControlNetPipeline.from_pretrained(
    # "runwayml/stable-diffusion-v1-5",
    "SG161222/Realistic_Vision_V3.0_VAE",
    controlnet=controlnet,
    safety_checker=None,
    torch_dtype=torch.float16,
)

# Memory savers for small GPUs.
# pipe.enable_xformers_memory_efficient_attention()
pipe.enable_attention_slicing(1)
# BUG FIX: the original called .to("cuda") before enable_model_cpu_offload().
# CPU offload manages device placement itself; moving the whole pipeline to
# CUDA first conflicts with it and can fail at runtime. Let offload own it.
pipe.enable_model_cpu_offload()
# pipe.enable_vae_tiling()
pipe.enable_vae_slicing()

# Default scheduler; swapped per request via SAMPLER_MAP in inference().
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

# UI sampler name -> scheduler factory.
# BUG FIX: the diffusers keyword is `use_karras_sigmas`; the original passed
# `use_karras`, which is silently ignored, so Karras schedules never applied.
SAMPLER_MAP = {
    "DPM++ Karras SDE": lambda config: DPMSolverMultistepScheduler.from_config(
        config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
    ),
    "DPM++ Karras": lambda config: DPMSolverMultistepScheduler.from_config(
        config, use_karras_sigmas=True
    ),
    "Heun": lambda config: HeunDiscreteScheduler.from_config(config),
    "Euler a": lambda config: EulerAncestralDiscreteScheduler.from_config(config),
    "Euler": lambda config: EulerDiscreteScheduler.from_config(config),
    "DDIM": lambda config: DDIMScheduler.from_config(config),
    "DEIS": lambda config: DEISMultistepScheduler.from_config(config),
}

# QR module ("box") size in pixels; all geometry in create_code aligns to it.
boxsize = 16
def create_code(content: str, errorCorrection: str): | |
match errorCorrection: | |
case "L 7%": | |
errCorr = qrcode.constants.ERROR_CORRECT_L | |
case "M 15%": | |
errCorr = qrcode.constants.ERROR_CORRECT_M | |
case "Q 25%": | |
errCorr = qrcode.constants.ERROR_CORRECT_Q | |
case "H 30%": | |
errCorr = qrcode.constants.ERROR_CORRECT_H | |
qr = qrcode.QRCode( | |
version=1, | |
error_correction=errCorr, | |
box_size=boxsize, | |
border=0, | |
) | |
qr.add_data(content) | |
qr.make(fit=True) | |
img = qr.make_image(fill_color="black", back_color="white") | |
# find smallest image size multiple of 256 that can fit qr | |
offset_min = 8 * boxsize | |
w, h = img.size | |
w = (w + 255 + offset_min) // 256 * 256 | |
h = (h + 255 + offset_min) // 256 * 256 | |
if w > 1024: | |
raise gr.Error("QR code is too large, please use a shorter content") | |
bg = Image.new('L', (w, h), 128) | |
# align on 16px grid | |
coords = ((w - img.size[0]) // 2 // boxsize * boxsize, | |
(h - img.size[1]) // 2 // boxsize * boxsize) | |
bg.paste(img, coords) | |
return bg | |
def inference(
    qr_code_content: str,
    errorCorrection: str,
    prompt: str,
    negative_prompt: str,
    inferenceSteps: float,
    guidance_scale: float = 10.0,
    controlnet_conditioning_scale: float = 2.0,
    seed: int = -1,
    sampler="Euler a",
):
    """Generate an artistic-yet-scannable QR code image.

    Args:
        qr_code_content: text/URL to encode (required).
        errorCorrection: QR error-correction level label for create_code.
        prompt: text prompt guiding the diffusion (required).
        negative_prompt: concepts to steer away from.
        inferenceSteps: diffusion step count (slider delivers a float).
        guidance_scale: classifier-free guidance strength.
        controlnet_conditioning_scale: readability vs. creativity trade-off.
        seed: RNG seed; -1 means unseeded (non-reproducible) generation.
        sampler: key into SAMPLER_MAP selecting the scheduler.

    Returns:
        The first generated PIL image.

    Raises:
        gr.Error: on missing prompt/content or oversized QR code.
    """
    if prompt is None or prompt == "":
        raise gr.Error("Prompt is required")
    if qr_code_content is None or qr_code_content == "":
        raise gr.Error("QR Code Content is required")

    # Swap the scheduler per request to honor the sampler dropdown.
    pipe.scheduler = SAMPLER_MAP[sampler](pipe.scheduler.config)
    generator = torch.manual_seed(seed) if seed != -1 else torch.Generator()

    print("Generating QR Code from content")
    qrcode_image = create_code(qr_code_content, errorCorrection)
    # BUG FIX: the original saved a debug copy to the hard-coded Windows path
    # "c:\\temp\\qr.jpg", which raises FileNotFoundError on Linux hosts such
    # as HF Spaces (a likely cause of the Space's "Runtime error"). Removed.

    out = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=qrcode_image,
        width=qrcode_image.width,
        height=qrcode_image.height,
        guidance_scale=float(guidance_scale),
        controlnet_conditioning_scale=float(controlnet_conditioning_scale),
        generator=generator,
        # BUG FIX: the slider supplies a float; diffusers expects an int.
        num_inference_steps=int(inferenceSteps),
    )
    return out.images[0]
# CSS for the result panel: center the output image and let it scale down
# responsively instead of overflowing its column.
# (Fixed: the scraped source had " | |" artifacts embedded inside this
# string, which would have been served to the browser as corrupt CSS.)
css = """
#result_image {
    display: flex;
    place-content: center;
    align-items: center;
}
#result_image > img {
    height: auto;
    max-width: 100%;
    width: revert;
}
"""
# --- Gradio UI --------------------------------------------------------------
# Left column: inputs + generation params; right column: result image.
with gr.Blocks(css=css) as blocks:
    with gr.Row():
        with gr.Column():
            qr_code_content = gr.Textbox(
                label="QR Code Content or URL",
                info="The text you want to encode into the QR code",
                value="",
            )
            errorCorrection = gr.Dropdown(
                label="QR Code Error Correction Level",
                choices=["L 7%", "M 15%", "Q 25%", "H 30%"],
                value="H 30%",  # highest redundancy survives heavy stylization
            )
            prompt = gr.Textbox(
                label="Prompt",
                info="Prompt that guides the generation towards",
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                value="ugly, disfigured, low quality, blurry, nsfw",
                info="Prompt that guides the generation away from",
            )
            inferenceSteps = gr.Slider(
                minimum=10.0,
                maximum=60.0,
                step=1,
                value=20,
                label="Inference Steps",
                info="More steps give better image but longer runtime",
            )
            with gr.Accordion(
                label="Params: The generated QR Code functionality is largely influenced by the parameters detailed below",
                open=True,
            ):
                controlnet_conditioning_scale = gr.Slider(
                    minimum=0.5,
                    maximum=2.5,
                    step=0.01,
                    value=1.5,
                    label="Controlnet Conditioning Scale",
                    info="""Controls the readability/creativity of the QR code.
High values: The generated QR code will be more readable.
Low values: The generated QR code will be more creative.
""",
                )
                guidance_scale = gr.Slider(
                    minimum=0.0,
                    maximum=25.0,
                    step=0.25,
                    value=7,
                    label="Guidance Scale",
                    info="Controls the amount of guidance the text prompt guides the image generation",
                )
                sampler = gr.Dropdown(
                    choices=list(SAMPLER_MAP.keys()), value="Euler a", label="Sampler"
                )
                seed = gr.Number(
                    minimum=-1,
                    maximum=9999999999,
                    value=-1,
                    label="Seed",
                    info="Seed for the random number generator. Set to -1 for a random seed",
                )
            with gr.Row():
                run_btn = gr.Button("Run")
        with gr.Column():
            result_image = gr.Image(label="Result Image", elem_id="result_image")

    # Wire the button to inference(); input order must match its signature.
    run_btn.click(
        inference,
        inputs=[
            qr_code_content,
            errorCorrection,
            prompt,
            negative_prompt,
            inferenceSteps,
            guidance_scale,
            controlnet_conditioning_scale,
            seed,
            sampler,
        ],
        outputs=[result_image],
    )

# Single worker, bounded queue (concurrency_count is the Gradio 3.x API).
blocks.queue(concurrency_count=1, max_size=20, api_open=False)
# SHARE env var toggles a public share link; API surface stays hidden.
blocks.launch(share=bool(os.environ.get("SHARE", True)), show_api=False)