import torch
from typing import Optional

from PIL import Image

import qrcode

from diffusers import (
    StableDiffusionControlNetImg2ImgPipeline,
    ControlNetModel,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    DEISMultistepScheduler,
    HeunDiscreteScheduler,
    EulerDiscreteScheduler,
)


SAMPLER_MAP = {
    "DPM++ Karras SDE": lambda config: DPMSolverMultistepScheduler.from_config(
        config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
    ),
    "DPM++ Karras": lambda config: DPMSolverMultistepScheduler.from_config(
        config, use_karras_sigmas=True
    ),
    "Heun": lambda config: HeunDiscreteScheduler.from_config(config),
    "Euler": lambda config: EulerDiscreteScheduler.from_config(config),
    "DDIM": lambda config: DDIMScheduler.from_config(config),
    "DEIS": lambda config: DEISMultistepScheduler.from_config(config),
}


def resize_for_condition_image(input_image: Image.Image, resolution: int):
    # Scale the image so its shorter side matches `resolution`, then round both
    # dimensions to multiples of 64 (valid Stable Diffusion input sizes).
    input_image = input_image.convert("RGB")
    W, H = input_image.size
    k = float(resolution) / min(H, W)
    H *= k
    W *= k
    H = int(round(H / 64.0)) * 64
    W = int(round(W / 64.0)) * 64
    img = input_image.resize((W, H), resample=Image.LANCZOS)
    return img


class EndpointHandler:
    def __init__(self, path=""):
        # Default generator; a fresh qrcode.QRCode is built per request in __call__.
        self.qrcode_generator = qrcode.QRCode(
            version=1,
            error_correction=qrcode.constants.ERROR_CORRECT_H,
            box_size=10,
            border=4,
        )

        controlnet = ControlNetModel.from_pretrained(
            "DionTimmer/controlnet_qrcode-control_v1p_sd15", torch_dtype=torch.float16
        )

        self.pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=controlnet,
            safety_checker=None,
            torch_dtype=torch.float16,
        ).to("cuda")
        self.pipe.enable_xformers_memory_efficient_attention()
        # Keep endpoint logs clean: no per-step progress bar.
        self.pipe.set_progress_bar_config(disable=True)

    def __call__(
        self,
        qr_code_content: str,
        prompt: str,
        negative_prompt: str,
        guidance_scale: float = 10.0,
        controlnet_conditioning_scale: float = 2.0,
        strength: float = 0.8,
        seed: int = -1,
        init_image: Optional[Image.Image] = None,
        qrcode_image: Optional[Image.Image] = None,
        use_qr_code_as_init_image: bool = True,
        sampler: str = "DPM++ Karras SDE",
    ):
        if prompt is None or prompt == "":
            raise ValueError("Prompt is required")

        if qrcode_image is None and qr_code_content == "":
            raise ValueError("QR Code Image or QR Code Content is required")

        # Swap in the requested scheduler.
        self.pipe.scheduler = SAMPLER_MAP[sampler](self.pipe.scheduler.config)

        generator = torch.manual_seed(seed) if seed != -1 else torch.Generator()

        # Render a QR code when content is given (or only a placeholder image was passed).
        if qr_code_content != "" or qrcode_image.size == (1, 1):
            qr = qrcode.QRCode(
                version=1,
                error_correction=qrcode.constants.ERROR_CORRECT_H,
                box_size=10,
                border=4,
            )
            qr.add_data(qr_code_content)
            qr.make(fit=True)
            qrcode_image = qr.make_image(fill_color="black", back_color="white")

        if init_image is None and use_qr_code_as_init_image:
            init_image = qrcode_image.convert("RGB")

        # Conditioning resolution used with this QR-code ControlNet.
        resolution = 768
        qrcode_image = resize_for_condition_image(qrcode_image, resolution)
        if init_image is not None:
            init_image = resize_for_condition_image(init_image.convert("RGB"), resolution)
        else:
            # No init image and the QR code is not used as one: start img2img
            # from a blank (black) canvas.
            init_image = Image.new("RGB", (resolution, resolution), (0, 0, 0))

        with torch.no_grad():
            out = self.pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=init_image,
                control_image=qrcode_image,
                width=resolution,
                height=resolution,
                guidance_scale=guidance_scale,
                controlnet_conditioning_scale=controlnet_conditioning_scale,
                strength=strength,
                generator=generator,
            )

        # The pipeline returns decoded PIL images directly.
        result_image = out.images[0]

        return result_image
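

# Minimal local usage sketch, not part of the endpoint contract: it assumes a
# CUDA GPU is available and that the handler is invoked directly with keyword
# arguments as defined above. The URL and prompts are placeholders.
if __name__ == "__main__":
    handler = EndpointHandler()
    image = handler(
        qr_code_content="https://example.com",
        prompt="a sprawling watercolor cityscape at dusk",
        negative_prompt="blurry, low quality, ugly",
        seed=42,
    )
    image.save("qr_art.png")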