Vivien Chappelier commited on
Commit
73c438e
1 Parent(s): 125c82c

add option for proxy model

Browse files
Files changed (1) hide show
  1. app.py +35 -5
app.py CHANGED
@@ -7,6 +7,7 @@ import numpy as np
7
 
8
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
9
 
 
10
  from diffusers import DiffusionPipeline, AutoencoderKL
11
  import torchvision.transforms as transforms
12
 
@@ -42,6 +43,11 @@ class BZHStableSignatureDemo(object):
42
 
43
  self.decoders = decoders
44
 
 
 
 
 
 
45
  def generate(self, mode, seed, prompt):
46
  generator = torch.Generator(device=device)
47
  torch.manual_seed(seed)
@@ -91,12 +97,16 @@ class BZHStableSignatureDemo(object):
91
  # JPEG attack
92
  mf = io.BytesIO()
93
  img.save(mf, format='JPEG', quality=jpeg_compression)
 
94
  mf.seek(0)
95
  img = Image.open(mf)
96
 
97
- return img
 
 
 
98
 
99
- def detect(self, img):
100
  # send to detection API and apply JPEG compression attack
101
  mf = io.BytesIO()
102
  img.save(mf, format='PNG')
@@ -115,6 +125,22 @@ class BZHStableSignatureDemo(object):
115
  data = response.json()
116
  pvalue = data['p-value']
117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  result = "No watermark detected."
119
  rpv = 10**int(math.log10(pvalue))
120
  if pvalue < 1e-3:
@@ -165,18 +191,22 @@ def interface():
165
  btn2 = gr.Button("Edit")
166
  with gr.Row():
167
  attacked_image = gr.Image(type="pil", width=512, sources=['upload', 'clipboard'])
 
 
168
 
169
  gr.Markdown("""## 3. Detect
170
  Detect the watermark on the altered image. Watermark may not be detected if the image is altered too strongly.
 
171
  """)
172
  with gr.Row():
 
173
  btn3 = gr.Button("Detect")
174
  with gr.Row():
175
- detection_label = gr.Label(label="Detection info", show_label=False)
176
 
177
  btn1.click(fn=backend.generate, inputs=[mode, seed, inp], outputs=[watermarked_image], api_name="generate")
178
- btn2.click(fn=backend.attack, inputs=[watermarked_image, jpeg_compression, downscale, crop, saturation, brightness, contrast], outputs=[attacked_image], api_name="attack")
179
- btn3.click(fn=backend.detect, inputs=[attacked_image], outputs=[detection_label], api_name="detect")
180
 
181
  return demo
182
 
 
7
 
8
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
9
 
10
+ from transformers import AutoModel, BlipImageProcessor
11
  from diffusers import DiffusionPipeline, AutoencoderKL
12
  import torchvision.transforms as transforms
13
 
 
43
 
44
  self.decoders = decoders
45
 
46
+ # load the proxy detector
47
+ self.detector_image_processor = BlipImageProcessor.from_pretrained("imatag/stable-signature-bzh-detector-resnet18")
48
+ commit_hash = "584a7bc01dc0f02e53bf8b8b295717ed09ed7294"
49
+ self.detector_model = AutoModel.from_pretrained("imatag/stable-signature-bzh-detector-resnet18", trust_remote_code=True, revision=commit_hash)
50
+
51
  def generate(self, mode, seed, prompt):
52
  generator = torch.Generator(device=device)
53
  torch.manual_seed(seed)
 
97
  # JPEG attack
98
  mf = io.BytesIO()
99
  img.save(mf, format='JPEG', quality=jpeg_compression)
100
+ filesize = mf.tell()
101
  mf.seek(0)
102
  img = Image.open(mf)
103
 
104
+ image_info = "resolution: %dx%d" % img.size
105
+ image_info += " JPEG file size: %d" % filesize
106
+
107
+ return img, image_info
108
 
109
+ def detect_api(self, img):
110
  # send to detection API and apply JPEG compression attack
111
  mf = io.BytesIO()
112
  img.save(mf, format='PNG')
 
125
  data = response.json()
126
  pvalue = data['p-value']
127
 
128
+ return pvalue
129
+
130
+ def detect_proxy(self, img):
131
+ img = img.convert("RGB")
132
+ inputs = self.detector_image_processor(img, return_tensors="pt")
133
+
134
+ with torch.no_grad():
135
+ pvalue = torch.sigmoid(self.detector_model(**inputs).logits).item()
136
+
137
+ return pvalue
138
+
139
+ def detect(self, img, detection_method):
140
+ if detection_method == "API":
141
+ pvalue = self.detect_api(img)
142
+ else:
143
+ pvalue = self.detect_proxy(img)
144
  result = "No watermark detected."
145
  rpv = 10**int(math.log10(pvalue))
146
  if pvalue < 1e-3:
 
191
  btn2 = gr.Button("Edit")
192
  with gr.Row():
193
  attacked_image = gr.Image(type="pil", width=512, sources=['upload', 'clipboard'])
194
+ with gr.Row():
195
+ image_info_label = gr.Label(label="Image info")
196
 
197
  gr.Markdown("""## 3. Detect
198
  Detect the watermark on the altered image. Watermark may not be detected if the image is altered too strongly.
199
+ You may choose to detect with our fast [proxy model](https://huggingface.co/imatag/stable-signature-bzh-detector-resnet18), or via API for improved robustness.
200
  """)
201
  with gr.Row():
202
+ detection_method = gr.Dropdown(choices=["proxy model", "API"], label="Detection method", value="proxy model")
203
  btn3 = gr.Button("Detect")
204
  with gr.Row():
205
+ detection_label = gr.Label(label="Detection info")
206
 
207
  btn1.click(fn=backend.generate, inputs=[mode, seed, inp], outputs=[watermarked_image], api_name="generate")
208
+ btn2.click(fn=backend.attack, inputs=[watermarked_image, jpeg_compression, downscale, crop, saturation, brightness, contrast], outputs=[attacked_image, image_info_label], api_name="attack")
209
+ btn3.click(fn=backend.detect, inputs=[attacked_image, detection_method], outputs=[detection_label], api_name="detect")
210
 
211
  return demo
212