File size: 9,259 Bytes
0b00c74
d2794b1
ca86cf6
8b44d8d
 
ca86cf6
 
 
 
b2cfd5f
464ec84
ca86cf6
b2cfd5f
 
ca86cf6
 
 
 
 
 
 
 
 
 
2342c94
 
ca86cf6
 
 
 
 
464ec84
ca86cf6
464ec84
 
 
 
 
 
 
 
 
 
 
 
ca86cf6
 
73c438e
 
b2cfd5f
 
 
 
73c438e
ca86cf6
 
2d11242
ca86cf6
464ec84
 
 
ca86cf6
 
464ec84
ca86cf6
bd9d89a
ca86cf6
4dfce87
e9d8edd
ca86cf6
 
 
 
fbe5687
464ec84
2342c94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca86cf6
 
 
bd9d89a
 
 
 
 
 
 
 
 
 
73c438e
bd9d89a
 
 
73c438e
 
 
 
ca86cf6
73c438e
ca86cf6
 
bd9d89a
ca86cf6
 
 
 
 
 
2880299
ca86cf6
7cfd7ed
ca86cf6
 
 
 
 
8b44d8d
73c438e
 
 
 
 
 
 
b2cfd5f
 
 
73c438e
 
 
 
 
 
 
 
ca86cf6
8fe2131
ca86cf6
bd9d89a
125c82c
bd9d89a
 
 
 
8b44d8d
ca86cf6
 
8b44d8d
ca86cf6
 
d2794b1
ca86cf6
 
2d11242
ca86cf6
 
74ae0b4
0b00c74
bd9d89a
 
ca86cf6
 
 
 
 
 
 
bd9d89a
 
 
 
 
ca86cf6
 
bd9d89a
 
 
 
 
 
 
 
 
 
 
73c438e
 
bd9d89a
 
 
73c438e
bd9d89a
 
73c438e
bd9d89a
 
73c438e
bd9d89a
ca86cf6
73c438e
 
0b00c74
ca86cf6
0b00c74
ca86cf6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
import gradio as gr

import os
import torch

import numpy as np

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

from transformers import AutoModelForImageClassification, BlipImageProcessor
from diffusers import DiffusionPipeline, AutoencoderKL
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
from safetensors import safe_open

from copy import deepcopy
from collections import OrderedDict

import requests
import json

from PIL import Image, ImageEnhance
import base64
import io
import random
import math

class BZHStableSignatureDemo(object):
    """Backend of the watermarked SDXL-Turbo demo.

    Holds the SDXL-Turbo pipeline, the original VAE together with the patched
    (watermarking) VAEs keyed by strength name, and the proxy watermark
    detector with its calibration logits. The public methods ``generate``,
    ``attack`` and ``detect`` are the three steps of the demo.
    """

    # Seconds to wait for the remote detection API before giving up, so that
    # an unresponsive server cannot block a request forever.
    API_TIMEOUT = 60

    def __init__(self, *args, **kwargs):
        """Load the diffusion pipeline, the patched VAEs and the proxy detector."""
        super().__init__(*args, **kwargs)

        # NOTE(review): the pipeline and VAEs are placed on "cuda"
        # unconditionally, so constructing this class needs a CUDA device.
        self.pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16").to("cuda")

        # disable invisible-watermark
        self.pipe.watermark = None

        # save the original VAE
        decoders = OrderedDict([("no watermark", self.pipe.vae)])

        # load the patched VAEs
        for name in ("weak", "medium", "strong", "extreme"):
            vae = AutoencoderKL.from_pretrained(f"imatag/stable-signature-bzh-sdxl-vae-{name}", torch_dtype=torch.float16).to("cuda")
            decoders[name] = vae

        self.decoders = decoders

        # load the proxy detector
        self.detector_image_processor = BlipImageProcessor.from_pretrained("imatag/stable-signature-bzh-detector-resnet18")
        self.detector_model = AutoModelForImageClassification.from_pretrained("imatag/stable-signature-bzh-detector-resnet18")
        calibration = hf_hub_download("imatag/stable-signature-bzh-detector-resnet18", filename="calibration.safetensors")
        with safe_open(calibration, framework="pt") as f:
            self.calibration_logits = f.get_tensor("logits")

    def generate(self, mode, seed, prompt):
        """Generate one image for ``prompt`` using the VAE selected by ``mode``.

        Args:
            mode: key of ``self.decoders`` ("no watermark", "weak", "medium",
                "strong" or "extreme").
            seed: value handed to ``torch.manual_seed`` before sampling.
            prompt: text prompt passed to the pipeline.

        Returns:
            The first image of the pipeline output (``output_type="pil"``).

        Raises:
            KeyError: if ``mode`` is not a key of ``self.decoders``.
        """
        # Seed the global torch RNG; the pipeline is called without an
        # explicit generator, so this is what makes a seed reproducible.
        torch.manual_seed(seed)

        # swap in the requested (possibly patched) VAE
        self.pipe.vae = self.decoders[mode]

        output = self.pipe(prompt, num_inference_steps=4, guidance_scale=0.0, output_type="pil")
        return output.images[0]

    @staticmethod
    def _random_crop(img, crop):
        """Return a randomly placed crop keeping about ``1 - crop`` of the area."""
        width, height = img.size
        target_area = width * height * (1 - crop)
        log_rmin = math.log(0.5)
        log_rmax = math.log(2.0)

        # Try a few log-uniform aspect ratios in [0.5, 2] until one fits.
        for _ in range(10):
            aspect_ratio = math.exp(random.random() * (log_rmax - log_rmin) + log_rmin)
            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))
            if 0 < w <= width and 0 < h <= height:
                break
        else:
            # No sampled ratio fitted (common for small crop ratios, where
            # almost only the image's own ratio fits). Keep the image's
            # aspect ratio, which always fits, instead of skipping the crop.
            scale = math.sqrt(1 - crop)
            w = min(width, max(1, int(round(width * scale))))
            h = min(height, max(1, int(round(height * scale))))

        # random.randint includes both endpoints, so the largest valid offset
        # is (size - crop_size); anything larger would leave the image.
        top = random.randint(0, height - h)
        left = random.randint(0, width - w)
        return img.crop((left, top, left + w, top + h))

    def attack(self, img, jpeg_compression, downscale, crop, saturation, brightness, contrast):
        """Apply the editing "attacks" to ``img`` and re-encode it as JPEG.

        The steps run in this order: downscale, random crop, color
        saturation, brightness, contrast, JPEG compression.

        Args:
            img: PIL image to alter.
            jpeg_compression: JPEG ``quality`` value used when re-encoding.
            downscale: factor both dimensions are divided by (1 = unchanged).
            crop: fraction of the image area to crop away (0 = no crop).
            saturation: factor for ``ImageEnhance.Color``.
            brightness: factor for ``ImageEnhance.Brightness``.
            contrast: factor for ``ImageEnhance.Contrast``.

        Returns:
            A tuple ``(image, image_info)``: the JPEG-decoded PIL image and a
            string giving its resolution and JPEG file size.
        """
        img = img.convert("RGB")

        if downscale != 1:
            width, height = img.size
            img = img.resize((int(width / downscale), int(height / downscale)), Image.Resampling.LANCZOS)

        if crop != 0:
            img = self._random_crop(img, crop)

        img = ImageEnhance.Color(img).enhance(saturation)
        img = ImageEnhance.Brightness(img).enhance(brightness)
        img = ImageEnhance.Contrast(img).enhance(contrast)

        # JPEG attack: encode into memory, then decode again
        mf = io.BytesIO()
        # the quality comes from a UI slider; make sure the encoder gets an int
        img.save(mf, format='JPEG', quality=int(jpeg_compression))
        filesize = mf.tell()
        mf.seek(0)
        img = Image.open(mf)

        image_info = "resolution: %dx%d" % img.size
        image_info += " JPEG file size: %d" % filesize

        return img, image_info

    def detect_api(self, img):
        """Return the p-value reported by the remote detection API for ``img``.

        The image is sent as a base64-encoded PNG. If the ``BZH_API_KEY``
        environment variable is set, it is sent in the ``x-api-key`` header.

        Raises:
            requests.HTTPError: if the API answers with an error status.
            requests.Timeout: if the API does not answer within
                ``API_TIMEOUT`` seconds.
        """
        # PNG keeps the upload lossless, so no further alteration is added here
        mf = io.BytesIO()
        img.save(mf, format='PNG')
        payload = {
            'image': base64.b64encode(mf.getvalue()).decode('utf8')
        }

        headers = {}
        api_key = os.getenv('BZH_API_KEY')
        if api_key:
            headers['x-api-key'] = api_key
        response = requests.post('https://bzh.imatag.com/bzh/api/v1.0/detect',
                                 json=payload, headers=headers,
                                 timeout=self.API_TIMEOUT)
        response.raise_for_status()
        return response.json()['p-value']

    def detect_proxy(self, img):
        """Return the p-value estimated by the local proxy detector for ``img``.

        The first logit of the detector is ranked among the calibration
        logits with ``torch.searchsorted``; the p-value is
        ``(1 + rank) / number_of_calibration_logits``.
        """
        img = img.convert("RGB")
        inputs = self.detector_image_processor(img, return_tensors="pt")

        with torch.no_grad():
            logit = self.detector_model(**inputs).logits[...,0]
            # assumes the calibration logits are stored sorted, as
            # searchsorted requires -- TODO confirm
            pvalue = (1 + torch.searchsorted(self.calibration_logits, logit)) / self.calibration_logits.shape[0]
            pvalue = pvalue.item()

        return pvalue

    def detect(self, img, detection_method):
        """Run watermark detection and format the outcome as a label dict.

        Args:
            img: PIL image to analyze.
            detection_method: "API" selects ``detect_api``; any other value
                selects ``detect_proxy``.

        Returns:
            A one-entry dict mapping a verdict message to a score in [0, 1].
            The score is ``floor(-log10(p-value)) / 10``, capped at 1.
        """
        if detection_method == "API":
            pvalue = self.detect_api(img)
        else:
            pvalue = self.detect_proxy(img)

        if pvalue < 1e-6:
            result = "Watermark detected with high confidence"
        elif pvalue < 1e-3:
            result = "Watermark detected with low confidence"
        else:
            result = "No watermark detected."

        if pvalue > 0:
            # clamp to [0, 10] so the displayed score never leaves [0, 1]
            score = max(0, min(int(-math.log10(pvalue)), 10))
        else:
            # log10 is undefined at 0; a p-value that underflowed to zero is
            # the strongest possible detection
            score = 10
        return { result: score/10 }

def interface():
    """Build the Gradio Blocks app for the three demo steps.

    Instantiates the backend, lays out the "Generate", "Edit" and "Detect"
    sections, wires each button to the matching backend method, and returns
    the (not yet launched) Blocks object.
    """
    default_prompt = "sailing ship in storm by Rembrandt"

    backend = BZHStableSignatureDemo()
    strength_choices = list(backend.decoders.keys())

    with gr.Blocks() as app:
        gr.Markdown("""# Watermarked SDXL-Turbo demo
        This demo brought to you by [IMATAG](https://www.imatag.com/) presents watermarking of images generated via [StableDiffusion XL Turbo](https://huggingface.co/stabilityai/sdxl-turbo).
        Using the method presented in [StableSignature](https://ai.meta.com/blog/stable-signature-watermarking-generative-ai/),
        the VAE decoder of StableDiffusion is fine-tuned to produce images including a specific invisible watermark. We combined
        this method with a demo version of [IMATAG](https://www.imatag.com/)'s in-house decoder. The watermarking system operates in zero-bit mode for improved robustness.""")

        # --- Step 1: generation controls and output image ---
        gr.Markdown("""## 1. Generate
        Select a watermarking strength and generate images with StableDiffusion-XL Turbo from prompt and seed as usual.""")
        with gr.Row():
            prompt_box = gr.Textbox(label="Prompt", value=default_prompt)
            seed_box = gr.Number(label="Seed", precision=0)
            strength_box = gr.Dropdown(choices=strength_choices, label="Watermark strength", value="medium")
        with gr.Row():
            generate_btn = gr.Button("Generate")
        with gr.Row():
            generated_view = gr.Image(type="pil", width=512, height=512, sources=[], interactive=False)

        # --- Step 2: editing ("attack") controls and edited image ---
        gr.Markdown("""## 2. Edit
        With these controls you may alter the generated image before detection. You may also upload your own edited image instead.""")
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    downscale_slider = gr.Slider(minimum=1, maximum=3, value=1, step=0.1, label="Downscale ratio")
                    crop_slider = gr.Slider(minimum=0, maximum=0.9, value=0, step=0.01, label="Random crop ratio")
                with gr.Row():
                    brightness_slider = gr.Slider(minimum=0, maximum=2, value=1, step=0.1, label="Brightness")
                    contrast_slider = gr.Slider(minimum=0, maximum=2, value=1, step=0.1, label="Contrast")
                with gr.Row():
                    saturation_slider = gr.Slider(minimum=0, maximum=2, value=1, step=0.1, label="Color saturation")
                    jpeg_slider = gr.Slider(value=100, step=5, label="JPEG quality")
                edit_btn = gr.Button("Edit")
        with gr.Row():
            edited_view = gr.Image(type="pil", width=512, sources=['upload', 'clipboard'])
        with gr.Row():
            info_label = gr.Label(label="Image info")

        # --- Step 3: detection controls and verdict ---
        gr.Markdown("""## 3. Detect
        Detect the watermark on the altered image. Watermark may not be detected if the image is altered too strongly.
        You may choose to detect with our fast [proxy model](https://huggingface.co/imatag/stable-signature-bzh-detector-resnet18), or via API for improved robustness.
        """)
        with gr.Row():
            method_box = gr.Dropdown(choices=["proxy model", "API"], label="Detection method", value="proxy model")
            detect_btn = gr.Button("Detect")
        with gr.Row():
            verdict_label = gr.Label(label="Detection info")

        # Event wiring: the input lists follow the parameter order of the
        # backend method each button calls.
        generate_inputs = [strength_box, seed_box, prompt_box]
        edit_inputs = [
            generated_view,
            jpeg_slider,
            downscale_slider,
            crop_slider,
            saturation_slider,
            brightness_slider,
            contrast_slider,
        ]
        detect_inputs = [edited_view, method_box]

        generate_btn.click(
            fn=backend.generate,
            inputs=generate_inputs,
            outputs=[generated_view],
            api_name="generate",
        )
        edit_btn.click(
            fn=backend.attack,
            inputs=edit_inputs,
            outputs=[edited_view, info_label],
            api_name="attack",
        )
        detect_btn.click(
            fn=backend.detect,
            inputs=detect_inputs,
            outputs=[verdict_label],
            api_name="detect",
        )

    return app

if __name__ == '__main__':
    # When run as a script, build the Gradio app and start serving it.
    demo = interface()
    demo.launch()