| | import torch |
| | from diffusers import ( |
| | StableDiffusionControlNetPipeline, |
| | ControlNetModel, |
| | EulerAncestralDiscreteScheduler, |
| | ) |
| | from typing import Dict, List, Any |
| | from xformers.ops import MemoryEfficientAttentionFlashAttentionOp |
| | import qrcode |
| | import os |
| | import base64 |
| | from io import BytesIO |
| | import json |
| | from PIL import Image |
| |
|
# Hugging Face repo id of the base checkpoint loaded into
# StableDiffusionControlNetPipeline by load_models().
MODEL_ID = "simdi/colorful_qr"

# Output image size in pixels, passed to the pipeline call.
WIDTH = HEIGHT = 768

# Candidate controlnet_conditioning_scale values as (tile, brightness) pairs,
# matching the order in which load_models() passes the two ControlNets.
# Ordered by increasing tile weight, then brightness weight, so a larger
# art_scale (see float_to_pair_index) selects a later pair.
WEIGHT_PAIRS = [
    (tile_scale, brightness_scale)
    for tile_scale in (0.25, 0.35, 0.45)
    for brightness_scale in (0.20, 0.25)
]
| |
|
| |
|
def float_to_pair_index(f: float, length=None):
    """Map a scale value in [0, 1] to an index into WEIGHT_PAIRS.

    The unit interval is split into ``length`` equal-width buckets and the
    index of the bucket containing ``f`` is returned. Inputs outside [0, 1]
    are clamped first, so for ``length >= 1`` the result is always a valid
    index.

    Args:
        f: Requested scale. Anything ``float()`` accepts (an int, a float,
            or a numeric string such as one read from a JSON request) is
            converted before clamping.
        length: Number of buckets. Defaults to ``len(WEIGHT_PAIRS)``.

    Returns:
        An int in ``range(length)`` (for ``length >= 1``).

    Raises:
        ValueError / TypeError: if ``f`` cannot be converted to float.
    """
    if length is None:
        length = len(WEIGHT_PAIRS)
    # Clamp into [0, 1] so out-of-range inputs pick the first/last bucket.
    f = max(0.0, min(float(f), 1.0))
    # f == 1.0 would land on index `length`, one past the end; pull it back.
    return min(int(f * length), length - 1)
| |
|
| |
|
def select_weight_pair(f: float):
    """Return the (tile, brightness) conditioning-scale pair chosen for ``f``.

    ``f`` is bucketed by float_to_pair_index, which clamps it into [0, 1],
    and the matching entry of WEIGHT_PAIRS is returned.
    """
    pair_index = float_to_pair_index(f)
    return WEIGHT_PAIRS[pair_index]
| |
|
| |
|
def load_models():
    """Build the two-ControlNet Stable Diffusion pipeline and move it to CUDA.

    Steps:
      1. Load the tile and brightness ControlNets in float16, in that order.
         The order must match the (tile, brightness) scale pairs in
         WEIGHT_PAIRS.
      2. Load MODEL_ID as a StableDiffusionControlNetPipeline wrapping both
         ControlNets. Only this download is given ``cache_dir="cache"``, a
         path relative to the working directory. The ControlNets use the
         library's default cache location.
      3. Move the pipeline to "cuda".
      4. Replace its scheduler with EulerAncestralDiscreteScheduler, built
         from the existing scheduler's config.
      5. Enable VAE slicing.

    Returns:
        The configured pipeline.

    NOTE(review): MemoryEfficientAttentionFlashAttentionOp is imported at
    module level but xformers attention is not enabled here — confirm
    whether that is intentional.
    """
    controlnet_repos = (
        "lllyasviel/control_v11f1e_sd15_tile",
        "ioclab/control_v1p_sd15_brightness",
    )
    controlnets = [
        ControlNetModel.from_pretrained(repo, torch_dtype=torch.float16)
        for repo in controlnet_repos
    ]

    pipeline = StableDiffusionControlNetPipeline.from_pretrained(
        MODEL_ID,
        controlnet=controlnets,
        torch_dtype=torch.float16,
        cache_dir="cache",
    )
    pipeline.to("cuda")

    base_config = pipeline.scheduler.config
    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(base_config)
    pipeline.enable_vae_slicing()
    return pipeline
| |
|
| |
|
def resize_for_condition_image(input_image, resolution: int):
    """Return an RGB copy of ``input_image`` resized for ControlNet conditioning.

    The image is scaled so that its shorter side is about ``resolution``
    pixels. Each side is then rounded to the nearest multiple of 64, using
    Python's ``round`` (ties go to the even multiple). Resampling uses
    Lanczos.

    Args:
        input_image: Image object providing ``convert``, ``size`` and
            ``resize`` (a PIL image or a wrapper that forwards to one).
        resolution: Target length of the shorter side before rounding.

    Returns:
        The resized RGB image.
    """
    rgb_image = input_image.convert("RGB")
    src_width, src_height = rgb_image.size
    scale = float(resolution) / min(src_height, src_width)

    def snap_to_64(side):
        # Scale first, then round to the closest multiple of 64 pixels.
        return int(round(side * scale / 64.0)) * 64

    target_size = (snap_to_64(src_width), snap_to_64(src_height))
    return rgb_image.resize(target_size, resample=Image.LANCZOS)
| |
|
| |
|
def generate_qr_code(content: str, resolution=None):
    """Render ``content`` as a QR code image sized for ControlNet conditioning.

    Encoding settings:
      - highest error-correction level (H);
      - 10-pixel boxes and a 2-box border;
      - ``fit=True``, so the qrcode library may pick a larger version than
        the initial ``version=1`` when ``content`` does not fit.

    The black-on-white result is passed through resize_for_condition_image.

    Args:
        content: Text or URL to encode.
        resolution: Target length of the shorter side, before rounding to a
            multiple of 64. Defaults to ``WIDTH``, so the conditioning image
            follows the configured generation size instead of a duplicated
            literal.

    Returns:
        The resized RGB image returned by resize_for_condition_image.
    """
    if resolution is None:
        # Resolved at call time so the default tracks the module constant.
        resolution = WIDTH

    # A freshly constructed QRCode is already cleared by its constructor,
    # so no explicit clear() is needed before adding data.
    qr = qrcode.QRCode(
        version=1,
        error_correction=qrcode.ERROR_CORRECT_H,
        box_size=10,
        border=2,
    )
    qr.add_data(content)
    qr.make(fit=True)
    img = qr.make_image(fill_color="black", back_color="white")
    return resize_for_condition_image(img, resolution)
| |
|
| |
|
def image_to_base64(image):
    """Serialize ``image`` as PNG and return the bytes as base64 text.

    Args:
        image: Object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        UTF-8 decoded base64 string, with no ``data:`` URI prefix.
    """
    with BytesIO() as buffer:
        image.save(buffer, format="PNG")
        png_bytes = buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode("utf-8")
| |
|
| |
|
def generate_image_with_conditioning_scale(**inputs):
    """Run the ControlNet pipeline once and package the images for the endpoint.

    Required keyword arguments (a missing one raises KeyError):
        styles: Sequence of prompt strings.
        pair: Conditioning scales, passed unchanged as
            ``controlnet_conditioning_scale``.
        pipe: The pipeline to call.
        qr_image: Conditioning image, supplied once per ControlNet.
        generator: Passed to the pipeline as its ``generator``.

    Returns:
        ``{"fields": [{"name": "output", "type": "Image", "value": uris}]}``
        where ``uris`` holds one ``data:image/png;base64,...`` string per
        image the pipeline returned.
    """
    prompts = inputs["styles"]
    scale_pair = inputs["pair"]
    pipeline = inputs["pipe"]
    condition_image = inputs["qr_image"]
    rng = inputs["generator"]

    output = pipeline(
        prompt=prompts,
        negative_prompt=[""] * len(prompts),
        width=WIDTH,
        height=HEIGHT,
        guidance_scale=7.0,
        generator=rng,
        num_inference_steps=25,
        # Images generated for each prompt.
        num_images_per_prompt=2,
        controlnet_conditioning_scale=scale_pair,
        # One conditioning image per ControlNet. This 2 is the ControlNet
        # count, unrelated to num_images_per_prompt above.
        image=[condition_image] * 2,
    )
    generated = output.images

    output_field = {
        "name": "output",
        "type": "Image",
        "value": [
            f"data:image/png;base64,{image_to_base64(picture)}"
            for picture in generated
        ],
    }
    return {"fields": [output_field]}
| |
|
| |
|
def generate_image(pipe, inputs):
    """Generate QR-code art for one request payload.

    ``inputs`` must contain:
        styles: A prompt string or a list of prompt strings. A single prompt
            (a bare string or a one-element list) is repeated five times.
        content: Text to encode in the QR code.
        art_scale: Value handed to select_weight_pair to choose the
            ControlNet conditioning scales.

    The QR code is rendered, an unseeded-by-caller ``torch.Generator()`` is
    created, and the scale pair is selected. The pipeline then runs under
    ``torch.inference_mode()`` and CUDA autocast.

    Returns:
        The response dict produced by generate_image_with_conditioning_scale.
    """
    raw_styles = inputs["styles"]
    prompts = [raw_styles] if isinstance(raw_styles, str) else raw_styles
    if len(prompts) == 1:
        prompts = prompts * 5
    content = inputs["content"]
    art_scale = inputs["art_scale"]

    with torch.inference_mode(), torch.autocast("cuda"):
        qr_image = generate_qr_code(content)
        generator = torch.Generator()
        pair = select_weight_pair(art_scale)
        return generate_image_with_conditioning_scale(
            styles=prompts,
            pair=pair,
            pipe=pipe,
            qr_image=qr_image,
            generator=generator,
        )
| |
|
| |
|
class EndpointHandler:
    """Callable request handler that holds a loaded pipeline.

    The pipeline is built once in ``__init__`` via load_models(). Each call
    forwards the request payload to generate_image.
    """

    def __init__(self, path=""):
        # NOTE(review): ``path`` is accepted but unused here; presumably the
        # hosting runtime passes a model directory — confirm.
        self._model = load_models()

    def __call__(self, model_input: Dict[str, Any]):
        """Run generation for ``model_input`` and return the response dict."""
        return generate_image(self._model, model_input)
| |
|