File size: 11,870 Bytes
3330f45
1fc8d06
b83d18f
3330f45
 
 
 
 
51276d0
6400e55
3330f45
 
 
 
51276d0
3330f45
184daa2
0e6feda
5cb3999
6400e55
c3ae240
6400e55
a3657ef
c3ae240
 
6400e55
51276d0
b83d18f
 
3330f45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c3ae240
a3657ef
 
 
3330f45
 
 
 
 
 
 
 
51276d0
a3657ef
3330f45
 
 
 
 
 
51276d0
 
3330f45
 
 
 
6400e55
 
 
ee61c84
 
 
 
 
c3ae240
 
 
ee61c84
3330f45
c3ae240
 
 
a3657ef
c3ae240
b83d18f
c3ae240
 
6400e55
 
 
 
 
 
 
 
c3ae240
 
ee61c84
c3ae240
 
6400e55
 
 
 
 
 
 
 
c3ae240
 
6a497cb
51276d0
c3ae240
 
 
 
 
 
51276d0
e8943d1
6400e55
a3657ef
51276d0
 
a3657ef
e8943d1
 
 
 
a3657ef
c3ae240
e8943d1
 
ee61c84
51276d0
e8943d1
ee61c84
a3657ef
 
c3ae240
51276d0
 
ee61c84
 
 
51276d0
 
c3ae240
51276d0
a3657ef
51276d0
 
 
 
c3ae240
51276d0
 
 
 
e8943d1
 
a3657ef
 
 
e8943d1
51276d0
 
e8943d1
 
51276d0
e8943d1
 
51276d0
9c3d0dc
51276d0
 
3330f45
a3657ef
c3ae240
d28e6fc
6400e55
c3ae240
 
 
 
 
3330f45
b83d18f
6400e55
c3ae240
a3657ef
6400e55
c3ae240
6400e55
 
c3ae240
6400e55
c3ae240
6400e55
 
c3ae240
a3657ef
c3ae240
 
 
a3657ef
c3ae240
 
a3657ef
c3ae240
 
 
a3657ef
0e6feda
d28e6fc
c3ae240
 
d28e6fc
 
c3ae240
 
a3657ef
6400e55
c3ae240
6400e55
c3ae240
 
 
 
 
6400e55
c3ae240
a3657ef
c3ae240
 
 
a3657ef
c3ae240
 
a3657ef
c3ae240
 
 
a3657ef
c3ae240
 
 
 
d28e6fc
 
3330f45
6a497cb
 
a3657ef
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
import os, gc, random, re
import gradio as gr
import torch, spaces
from PIL import Image, ImageFilter
import numpy as np
import qrcode
from qrcode.constants import ERROR_CORRECT_H
from diffusers import (
    StableDiffusionControlNetPipeline,
    StableDiffusionControlNetImg2ImgPipeline,   # for Hi-Res Fix
    ControlNetModel,
    DPMSolverMultistepScheduler,
)

# Quiet matplotlib cache warning on Spaces
os.environ.setdefault("MPLCONFIGDIR", "/tmp/mpl")



# ---- base models for the two tabs ----
# Key = internal tab id used by the GPU wrappers; value = Hugging Face Hub repo id.
BASE_MODELS = {
    # The original "runwayml/stable-diffusion-v1-5" repo was removed from the Hub;
    # this repo is a mirror of the same SD 1.5 weights.
    "stable-diffusion-v1-5": "stable-diffusion-v1-5/stable-diffusion-v1-5",
    "dream":    "Lykon/dreamshaper-8",
}

# ControlNet: QR Code Monster for SD 1.5, loaded from the repo root.
# NOTE(review): an older comment called this "v2"; the repo ships its v2
# checkpoint in a "v2" subfolder (pass subfolder="v2" to from_pretrained) —
# confirm which version is intended.
CN_QRMON = "monster-labs/control_v1p_sd15_qrcode_monster"
# Half precision for the ControlNet and all pipelines.
DTYPE = torch.float16

# ---------- helpers ----------
def snap8(x: int) -> int:
    """Clamp ``x`` to [256, 1024] and round it down to a multiple of 8.

    Stable Diffusion width/height must be divisible by 8, so every canvas
    size in this app goes through here first.
    """
    clamped = min(1024, max(256, int(x)))
    return (clamped // 8) * 8

# "rgb(r, g, b)" / "rgba(r, g, b, a)" with integer or decimal channels. Each
# channel must contain at least one digit, so "rgb(., ., .)" does not match.
_RGB_FUNC_RE = re.compile(
    r"rgba?\(\s*(\d+(?:\.\d*)?|\.\d+)\s*,\s*(\d+(?:\.\d*)?|\.\d+)\s*,\s*(\d+(?:\.\d*)?|\.\d+)",
    re.IGNORECASE,
)

def _to_channel(v) -> int:
    """Round a numeric color component and clamp it to the 0..255 byte range."""
    return int(max(0, min(255, round(float(v)))))

def normalize_color(c):
    """Coerce a color value into a form PIL accepts.

    - None -> "white".
    - tuple/list -> (r, g, b) ints built from the first three components,
      rounded and clamped to 0..255 (extra items such as alpha are ignored).
    - str (surrounding whitespace stripped):
        * blank -> "white" (an empty string is not a valid PIL color);
        * "#..." hex -> returned as-is;
        * "rgb(...)" / "rgba(...)" (case-insensitive, float channels allowed)
          -> (r, g, b) ints, rounded and clamped to 0..255;
        * anything else (e.g. a color name) -> returned unchanged.
    - any other type -> "white".
    """
    if c is None:
        return "white"
    if isinstance(c, (tuple, list)):
        r, g, b = (_to_channel(x) for x in c[:3])
        return (r, g, b)
    if isinstance(c, str):
        s = c.strip()
        if not s:
            return "white"
        if s.startswith("#"):
            return s
        m = _RGB_FUNC_RE.match(s)
        if m:
            r, g, b = (_to_channel(v) for v in m.groups())
            return (r, g, b)
        return s
    return "white"

def make_qr(url="https://example.com", size=768, border=12, back_color="#FFFFFF", blur_radius=0.0):
    """
    Render ``url`` as a ``size`` x ``size`` RGB QR code image.

    IMPORTANT for Method 1: give ControlNet a sharp, black-on-WHITE QR (no blur).

    Uses error-correction level H and lets qrcode choose the smallest version
    that fits the data (``version=None`` + ``fit=True``). ``border`` is the
    quiet-zone width in modules. Modules are drawn black on ``back_color``
    (passed through normalize_color). The image is scaled with nearest-neighbour
    resampling so module edges stay hard. It is then blurred only when
    ``blur_radius`` is a positive number.
    """
    code = qrcode.QRCode(
        version=None,
        error_correction=ERROR_CORRECT_H,
        box_size=10,
        border=int(border),
    )
    code.add_data(url.strip())
    code.make(fit=True)
    rendered = code.make_image(fill_color="black", back_color=normalize_color(back_color)).convert("RGB")
    side = int(size)
    rendered = rendered.resize((side, side), Image.NEAREST)
    if blur_radius and blur_radius > 0:
        return rendered.filter(ImageFilter.GaussianBlur(radius=float(blur_radius)))
    return rendered

def enforce_qr_contrast(stylized: Image.Image, qr_img: Image.Image, strength: float = 0.0, feather: float = 1.0) -> Image.Image:
    """Optional gentle repair that pushes pixels toward the QR's dark/light modules.

    Default OFF for Method 1: a ``strength`` <= 0 returns ``stylized`` unchanged.

    Under a dark QR module (luminance < 128) a pixel is scaled toward black by
    up to ``strength``. Under a light module it is blended toward white by up
    to ``0.85 * strength``. The module mask is feathered with a Gaussian blur of
    radius ``feather`` px.

    The QR mask is resized (nearest-neighbour) to ``stylized``'s size when they
    differ. This happens, for example, with a Hi-Res Fix output checked against
    the lower-resolution control QR. Without the resize, the numpy broadcast
    below would fail.

    Args:
        stylized: generated image to repair.
        qr_img: QR code image used to build the module mask.
        strength: blend amount, typically 0..1; <= 0 disables the repair.
        feather: blur radius (px) softening the mask edges.

    Returns:
        A new RGB image, or ``stylized`` itself when the repair is disabled.
    """
    strength = float(strength)
    if strength <= 0:
        return stylized
    rgb = stylized.convert("RGB")
    q = qr_img.convert("L")
    if q.size != rgb.size:
        q = q.resize(rgb.size, Image.NEAREST)
    black_mask = q.point(lambda p: 255 if p < 128 else 0).filter(ImageFilter.GaussianBlur(radius=float(feather)))
    # Trailing axis so the (H, W) mask broadcasts over the (H, W, 3) image.
    black = (np.asarray(black_mask, dtype=np.float32) / 255.0)[..., None]
    white = 1.0 - black
    s = np.asarray(rgb, dtype=np.float32) / 255.0
    s = s * (1.0 - strength * black)
    s = s + (1.0 - s) * (strength * 0.85 * white)
    s = np.clip(s, 0.0, 1.0)
    # A uint8 (H, W, 3) array is inferred as mode "RGB".
    return Image.fromarray((s * 255.0).astype(np.uint8))

# ---------- lazy pipelines (CPU-offloaded for ZeroGPU) ----------
# Module-level caches, filled on first use by the loader functions below.
_CN = None               # the single shared QR Monster ControlNetModel
_CN_TXT2IMG = dict()     # model_id -> ControlNet txt2img pipeline
_CN_IMG2IMG = dict()     # model_id -> ControlNet img2img pipeline

def _base_scheduler_for(pipe):
    """Apply the shared runtime setup to a freshly loaded pipeline and return it.

    Replaces the scheduler with DPM-Solver++ (multistep, Karras sigmas), built
    from the pipeline's existing scheduler config. Then enables attention
    slicing, VAE slicing and model CPU offload, in that order, to cut GPU
    memory use.
    """
    dpm = DPMSolverMultistepScheduler.from_config(
        pipe.scheduler.config,
        use_karras_sigmas=True,
        algorithm_type="dpmsolver++",
    )
    pipe.scheduler = dpm
    for enable in (pipe.enable_attention_slicing,
                   pipe.enable_vae_slicing,
                   pipe.enable_model_cpu_offload):
        enable()
    return pipe

def get_cn():
    """Return the shared QR Monster ControlNet, loading it (DTYPE, safetensors) on the first call."""
    global _CN
    if _CN is not None:
        return _CN
    _CN = ControlNetModel.from_pretrained(CN_QRMON, torch_dtype=DTYPE, use_safetensors=True)
    return _CN

def _load_qrmon_pipe(pipeline_cls, model_id: str):
    """Load ``pipeline_cls`` for ``model_id`` with the shared QR ControlNet and apply the common setup."""
    pipe = pipeline_cls.from_pretrained(
        model_id,
        controlnet=get_cn(),
        torch_dtype=DTYPE,
        safety_checker=None,
        use_safetensors=True,
        low_cpu_mem_usage=True,
    )
    return _base_scheduler_for(pipe)

def get_qrmon_txt2img_pipe(model_id: str):
    """Return the cached ControlNet txt2img pipeline for ``model_id``, loading it on first use."""
    if model_id not in _CN_TXT2IMG:
        _CN_TXT2IMG[model_id] = _load_qrmon_pipe(StableDiffusionControlNetPipeline, model_id)
    return _CN_TXT2IMG[model_id]

def get_qrmon_img2img_pipe(model_id: str):
    """Return the cached ControlNet img2img (Hi-Res Fix) pipeline for ``model_id``, loading it on first use."""
    if model_id not in _CN_IMG2IMG:
        _CN_IMG2IMG[model_id] = _load_qrmon_pipe(StableDiffusionControlNetImg2ImgPipeline, model_id)
    return _CN_IMG2IMG[model_id]

# -------- Method 1: QR control model in text-to-image (+ optional Hi-Res Fix) --------
def _release_gpu_cache():
    """Drop cached CUDA allocations (when CUDA is present) and run the Python GC."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

def _qr_txt2img_core(model_id: str,
                     url: str, style_prompt: str, negative: str,
                     steps: int, cfg: float, size: int, border: int,
                     qr_weight: float, seed: int,
                     use_hires: bool, hires_upscale: float, hires_strength: float,
                     repair_strength: float, feather: float):
    """Generate a QR-controlled image (Method 1), optionally refined by a Hi-Res Fix pass.

    Stage A runs ControlNet txt2img at ``snap8(size)`` with a crisp black-on-white
    QR of ``url`` as the control image. When ``use_hires`` is set, Stage B
    upscales the Stage A result by ``hires_upscale`` (clamped to 1..2x, and to
    at most 1024 px by snap8). It then refines the result with ControlNet
    img2img, using a QR re-rendered at the target size so module edges stay
    sharp.

    The optional contrast repair (``repair_strength`` > 0) is applied once to
    the Stage A image. If Hi-Res Fix runs, it is applied once more to its output.

    Args:
        model_id: Hub repo id of the SD 1.5 base model.
        url: text encoded into the QR.
        style_prompt / negative: prompt and negative prompt.
        steps, cfg: inference steps and classifier-free guidance scale.
        size: requested canvas size in px; border: QR quiet zone in modules.
        qr_weight: ControlNet conditioning scale.
        seed: generator seed; None or negative picks a random seed.
        use_hires, hires_upscale, hires_strength: Hi-Res Fix switch, scale and
            img2img denoise strength.
        repair_strength, feather: post-repair settings (see enforce_qr_contrast).

    Returns:
        (final, lowres, qr_img): final image, Stage A image, and Stage A control QR.
    """
    s = snap8(size)

    # Control image: crisp black-on-white QR
    qr_img = make_qr(url=url, size=s, border=int(border), back_color="#FFFFFF", blur_radius=0.0)

    # Seed / generator (a cleared Gradio Number can arrive as None)
    seed = -1 if seed is None else int(seed)
    if seed < 0:
        seed = random.randint(0, 2**31 - 1)
    gen = torch.Generator(device="cuda").manual_seed(seed)

    # Arguments shared by both ControlNet passes.
    common = dict(
        prompt=str(style_prompt),
        negative_prompt=str(negative or ""),
        controlnet_conditioning_scale=float(qr_weight),  # ~1.0–1.2 works well
        control_guidance_start=0.0,
        control_guidance_end=1.0,
        num_inference_steps=int(steps),
        guidance_scale=float(cfg),
        generator=gen,
    )
    repair = float(repair_strength)
    feather = float(feather)

    # --- Stage A: txt2img with ControlNet
    pipe = get_qrmon_txt2img_pipe(model_id)
    _release_gpu_cache()
    with torch.autocast(device_type="cuda", dtype=DTYPE):
        out = pipe(image=qr_img, width=s, height=s, **common)  # image = control image for txt2img
    lowres = enforce_qr_contrast(out.images[0], qr_img, strength=repair, feather=feather)

    if not use_hires:
        # lowres already carries the repair; repairing it again would double the effect.
        return lowres, lowres, qr_img

    # --- Optional Stage B: Hi-Res Fix (img2img with the same QR)
    up = max(1.0, min(2.0, float(hires_upscale)))
    side = snap8(int(s * up))
    if side == s:
        qr_hi, init = qr_img, lowres
    else:
        # Render the control QR natively at the target size (sharp module edges).
        # Upscale the init image explicitly, so it matches the target size
        # regardless of how the pipeline preprocesses its inputs.
        qr_hi = make_qr(url=url, size=side, border=int(border), back_color="#FFFFFF", blur_radius=0.0)
        init = lowres.resize((side, side), Image.LANCZOS)

    pipe2 = get_qrmon_img2img_pipe(model_id)
    _release_gpu_cache()
    with torch.autocast(device_type="cuda", dtype=DTYPE):
        out2 = pipe2(
            image=init,                        # init image
            control_image=qr_hi,               # same QR, at the target size
            strength=float(hires_strength),    # ~0.7 like "Hires Fix"
            width=side, height=side,
            **common,
        )
    # Repair against the QR that matches the output resolution.
    final = enforce_qr_contrast(out2.images[0], qr_hi, strength=repair, feather=feather)
    return final, lowres, qr_img

# Per-tab GPU entry points: each one fixes the base model, so the Gradio
# button only has to pass the UI inputs.
@spaces.GPU(duration=120)
def qr_txt2img_sd15(*args):
    """Run the QR pipeline on the stable-diffusion-v1-5 base model with the tab's inputs."""
    model_id = BASE_MODELS["stable-diffusion-v1-5"]
    return _qr_txt2img_core(model_id, *args)

@spaces.GPU(duration=120)
def qr_txt2img_dream(*args):
    """Run the QR pipeline on the DreamShaper 8 base model with the tab's inputs."""
    model_id = BASE_MODELS["dream"]
    return _qr_txt2img_core(model_id, *args)

# ---------- UI ----------
# Two tabs with the same controls, each bound to one base model. Components
# appear in the order they are created inside `gr.Blocks`. Each button's
# inputs list must follow _qr_txt2img_core's parameters after model_id:
# url, style_prompt, negative, steps, cfg, size, border, qr_weight, seed,
# use_hires, hires_upscale, hires_strength, repair_strength, feather.
# The outputs list maps to its return tuple (final, lowres, qr_img).
with gr.Blocks() as demo:
    gr.Markdown("# ZeroGPU • Method 1: QR Control (two base models)")

    # ---- Tab 1: stable-diffusion-v1-5 (base SD 1.5; default prompt is anime/illustration-styled) ----
    with gr.Tab("stable-diffusion-v1-5"):
        url1       = gr.Textbox(label="URL/Text", value="http://www.mybirdfire.com")
        s_prompt1  = gr.Textbox(label="Style prompt", value="japanese painting, elegant shrine and torii, distant mount fuji, autumn maple trees, warm sunlight, 1girl in kimono, highly detailed, intricate patterns, anime key visual, dramatic composition")
        s_negative1= gr.Textbox(label="Negative prompt", value="ugly, low quality, blurry, nsfw, watermark, text, low contrast, deformed, extra digits")
        # The core snaps this with snap8 (multiple of 8, clamped to 256..1024).
        size1      = gr.Slider(384, 1024, value=512, step=64, label="Canvas (px)")
        steps1     = gr.Slider(10, 50, value=20, step=1, label="Steps")
        cfg1       = gr.Slider(1.0, 12.0, value=7.0, step=0.1, label="CFG")
        # Border is passed to qrcode as the quiet-zone width in modules.
        border1    = gr.Slider(2, 16, value=4, step=1, label="QR border (quiet zone)")
        qr_w1      = gr.Slider(0.6, 1.6, value=1.5, step=0.05, label="QR control weight")
        seed1      = gr.Number(value=-1, precision=0, label="Seed (-1 random)")

        # Hi-Res target is snap8(canvas * upscale), so it never exceeds 1024 px.
        use_hires1 = gr.Checkbox(value=True, label="Hi-Res Fix (img2img upscale)")
        hires_up1  = gr.Slider(1.0, 2.0, value=2.0, step=0.25, label="Hi-Res upscale (×)")
        hires_str1 = gr.Slider(0.3, 0.9, value=0.7, step=0.05, label="Hi-Res denoise strength")

        # Repair strength 0 leaves the generated image untouched.
        repair1    = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Post repair strength (optional)")
        feather1   = gr.Slider(0.0, 3.0, value=1.0, step=0.1, label="Repair feather (px)")

        final_img1 = gr.Image(label="Final (or Hi-Res) image")
        low_img1   = gr.Image(label="Low-res (Stage A) preview")
        ctrl_img1  = gr.Image(label="Control QR used")

        # api_name also exposes this handler as a named API endpoint.
        gr.Button("Generate with stable-diffusion-v1-5").click(
            qr_txt2img_sd15,
            [url1, s_prompt1, s_negative1, steps1, cfg1, size1, border1, qr_w1, seed1,
             use_hires1, hires_up1, hires_str1, repair1, feather1],
            [final_img1, low_img1, ctrl_img1],
             api_name="qr_txt2img_sd15"
        )

    # ---- Tab 2: DreamShaper (general art/painterly) ----
    # Same controls as Tab 1; only the defaults and the bound wrapper differ.
    with gr.Tab("DreamShaper 8"):
        url2       = gr.Textbox(label="URL/Text", value="http://www.mybirdfire.com")
        s_prompt2  = gr.Textbox(label="Style prompt", value="ornate baroque palace interior, gilded details, chandeliers, volumetric light, ultra detailed, cinematic")
        s_negative2= gr.Textbox(label="Negative prompt", value="lowres, low contrast, blurry, jpeg artifacts, watermark, text, bad anatomy")
        size2      = gr.Slider(384, 1024, value=512, step=64, label="Canvas (px)")
        steps2     = gr.Slider(10, 50, value=24, step=1, label="Steps")
        cfg2       = gr.Slider(1.0, 12.0, value=6.8, step=0.1, label="CFG")
        border2    = gr.Slider(2, 16, value=8, step=1, label="QR border (quiet zone)")
        qr_w2      = gr.Slider(0.6, 1.6, value=1.5, step=0.05, label="QR control weight")
        seed2      = gr.Number(value=-1, precision=0, label="Seed (-1 random)")

        use_hires2 = gr.Checkbox(value=True, label="Hi-Res Fix (img2img upscale)")
        hires_up2  = gr.Slider(1.0, 2.0, value=2.0, step=0.25, label="Hi-Res upscale (×)")
        hires_str2 = gr.Slider(0.3, 0.9, value=0.7, step=0.05, label="Hi-Res denoise strength")

        repair2    = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Post repair strength (optional)")
        feather2   = gr.Slider(0.0, 3.0, value=1.0, step=0.1, label="Repair feather (px)")

        final_img2 = gr.Image(label="Final (or Hi-Res) image")
        low_img2   = gr.Image(label="Low-res (Stage A) preview")
        ctrl_img2  = gr.Image(label="Control QR used")

        gr.Button("Generate with DreamShaper 8").click(
            qr_txt2img_dream,
            [url2, s_prompt2, s_negative2, steps2, cfg2, size2, border2, qr_w2, seed2,
             use_hires2, hires_up2, hires_str2, repair2, feather2],
            [final_img2, low_img2, ctrl_img2],
            api_name="qr_txt2img_dream" 
        )

if __name__ == "__main__":
    # Enable request queuing (at most 12 waiting requests), then start the server.
    app = demo.queue(max_size=12)
    app.launch()