xiaoyao9184 committed on
Commit
2718c61
1 Parent(s): 2e6ccc7

Support gradio

Browse files
Files changed (6) hide show
  1. .gitignore +2 -0
  2. .vscode/launch.json +16 -0
  3. README.md +3 -2
  4. app.py +44 -0
  5. gradio_app.py +388 -0
  6. requirements.txt +4 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *.pyc
2
+ watermark-anything
.vscode/launch.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": "0.2.0",
3
+ "configurations": [
4
+ {
5
+ "name": "debugpy: app.py",
6
+ "type": "debugpy",
7
+ "request": "launch",
8
+ "program": "${workspaceFolder}/app.py",
9
+ "console": "integratedTerminal",
10
+ "env": {
11
+ // "HF_ENDPOINT": "https://hf-mirror.com"
12
+ },
13
+ "justMyCode": false
14
+ }
15
+ ]
16
+ }
README.md CHANGED
@@ -1,10 +1,11 @@
1
  ---
2
  title: Watermark Anything
3
- emoji:
4
  colorFrom: gray
5
  colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 5.9.0
 
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
1
  ---
2
  title: Watermark Anything
3
+ emoji: 💧
4
  colorFrom: gray
5
  colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 5.8.0
8
+ python_version: '3.10.14'
9
  app_file: app.py
10
  pinned: false
11
  license: apache-2.0
app.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import git
4
+ import subprocess
5
+ from huggingface_hub import hf_hub_download
6
+
7
# Upstream project to clone; pinned to an exact commit for reproducibility.
REPO_URL = "https://github.com/facebookresearch/watermark-anything.git"
REPO_BRANCH = '88e3ae5d5866a7daaac167ea202a61a7d69ef590'
# Local clone destination (listed in .gitignore so it is never committed).
LOCAL_PATH = "./watermark-anything"
# Hugging Face model repository that hosts the checkpoint file.
MODEL_ID = "xiaoyao9184/watermark-anything"
11
+
12
+ def install_src():
13
+ if not os.path.exists(LOCAL_PATH):
14
+ print(f"Cloning repository from {REPO_URL}...")
15
+ repo = git.Repo.clone_from(REPO_URL, LOCAL_PATH)
16
+ repo.git.checkout(REPO_BRANCH)
17
+ else:
18
+ print(f"Repository already exists at {LOCAL_PATH}")
19
+
20
+ requirements_path = os.path.join(LOCAL_PATH, "requirements.txt")
21
+ if os.path.exists(requirements_path):
22
+ print("Installing requirements...")
23
+ subprocess.check_call(["pip", "install", "-r", requirements_path])
24
+ else:
25
+ print("No requirements.txt found.")
26
+
27
+ def install_model():
28
+ checkpoint_path = os.path.join(LOCAL_PATH, "checkpoints")
29
+ hf_hub_download(repo_id=MODEL_ID, filename='checkpoint.pth', local_dir=checkpoint_path)
30
+
31
+ # clone repo and download model
32
+ install_src()
33
+ install_model()
34
+
35
+ # change directory
36
+ print(f"Current Directory: {os.getcwd()}")
37
+ os.chdir(LOCAL_PATH)
38
+ print(f"New Directory: {os.getcwd()}")
39
+
40
+ # fix sys.path for import
41
+ sys.path.append(os.getcwd())
42
+
43
+ # run gradio
44
+ import gradio_app
gradio_app.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ import gradio as gr
4
+
5
+ import re
6
+ import string
7
+ import random
8
+ import os
9
+ import numpy as np
10
+ from PIL import Image
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torchvision import transforms
14
+
15
+
16
+ from watermark_anything.data.metrics import msg_predict_inference
17
+ from notebooks.inference_utils import (
18
+ load_model_from_checkpoint,
19
+ default_transform,
20
+ unnormalize_img,
21
+ create_random_mask,
22
+ plot_outputs,
23
+ msg2str,
24
+ torch_to_np,
25
+ multiwm_dbscan
26
+ )
27
+
# Run inference on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tuning constants for embedding and detection.
proportion_masked = 0.5  # proportion of the image to be watermarked
epsilon = 1              # min distance between decoded messages in a cluster
min_samples = 500        # min pixels in a 256x256 map needed to form a cluster

# RGB colours used to visualise detected watermark clusters (keyed by cluster id).
color_map = {
    -1: [0, 0, 0],       # black: background / unclustered
    0: [255, 0, 255],    # magenta
    1: [255, 0, 0],      # red
    2: [0, 255, 0],      # green
    3: [0, 0, 255],      # blue
    4: [255, 255, 0],    # yellow
}
45
+
def load_wam():
    """Load the Watermark Anything model from the local checkpoints directory
    and return it on `device` in eval mode."""
    ckpt_dir = "checkpoints"
    params_file = os.path.join(ckpt_dir, "params.json")
    weights_file = os.path.join(ckpt_dir, 'checkpoint.pth')
    model = load_model_from_checkpoint(params_file, weights_file)
    return model.to(device).eval()
53
+
def image_detect(img_pil: Image.Image) -> tuple:
    """Run watermark detection on a PIL image.

    Args:
        img_pil: Input image (converted to RGB by the caller).

    Returns:
        Tuple ``(img_pt, hard_mask, positions, centroids, centroids_pt)``:
        - img_pt: normalized input tensor [1, 3, H, W]
        - hard_mask: thresholded (>0.5) predicted mask, float, [1, 1, H, W]
        - positions: per-pixel cluster-id map, or None when clustering failed
        - centroids: dict of cluster id -> decoded message bits, or None
        - centroids_pt: stacked centroid tensor, or None

    Note: the original declared ``-> (torch.Tensor, ...)`` which is a tuple
    literal evaluated at def time, not a valid return annotation; fixed here.
    """
    img_pt = default_transform(img_pil).unsqueeze(0).to(device)  # [1, 3, H, W]

    # Detect the watermark in the (possibly multi-)watermarked image.
    preds = wam.detect(img_pt)["preds"]  # [1, 33, 256, 256]
    mask_preds = F.sigmoid(preds[:, 0, :, :])  # [1, 256, 256], predicted mask
    mask_preds_res = F.interpolate(mask_preds.unsqueeze(1), size=(img_pt.shape[-2], img_pt.shape[-1]), mode="bilinear", align_corners=False)  # [1, 1, H, W]
    bit_preds = preds[:, 1:, :, :]  # [1, 32, 256, 256], predicted bits

    # `positions` has the cluster number at each pixel; it can be upscaled
    # back to the original size.
    try:
        centroids, positions = multiwm_dbscan(bit_preds, mask_preds, epsilon=epsilon, min_samples=min_samples)
        centroids_pt = torch.stack(list(centroids.values()))
    except (UnboundLocalError) as e:
        # NOTE(review): presumably multiwm_dbscan raises UnboundLocalError when
        # no cluster is found — confirm against the upstream implementation.
        print(f"Error while detecting watermark: {e}")
        positions = None
        centroids = None
        centroids_pt = None

    return img_pt, (mask_preds_res>0.5).float(), positions, centroids, centroids_pt
74
+
def image_embed(img_pil: Image.Image, wm_msgs: torch.Tensor, wm_masks: torch.Tensor) -> tuple:
    """Embed one 32-bit message per mask into the image.

    Args:
        img_pil: Input image (RGB).
        wm_msgs: Messages to embed, one row of 32 bits per watermark.
        wm_masks: One binary mask per message; 1 marks pixels to watermark.

    Returns:
        Tuple ``(img_pt, multi_wm_img, merged_mask)`` — the normalized input,
        the watermarked image tensor, and the union of all masks.

    Note: the original declared ``-> (torch.Tensor, ...)`` which is a tuple
    literal evaluated at def time, not a valid return annotation; fixed here.
    """
    img_pt = default_transform(img_pil).unsqueeze(0).to(device)  # [1, 3, H, W]

    # Compose the result mask by mask: each message is embedded into the
    # *original* image, then pasted only where its mask is 1.
    multi_wm_img = img_pt.clone()
    for ii in range(len(wm_msgs)):
        wm_msg, mask = wm_msgs[ii].unsqueeze(0), wm_masks[ii]
        outputs = wam.embed(img_pt, wm_msg)
        multi_wm_img = outputs['imgs_w'] * mask + multi_wm_img * (1 - mask)

    # Release cached GPU memory between requests (no-op when CUDA is unused).
    torch.cuda.empty_cache()
    return img_pt, multi_wm_img, wm_masks.sum(0)
88
+
def create_bounding_mask(img_size, boxes):
    """Build a binary mask that is 1 inside the given boxes, 0 elsewhere.

    Args:
        img_size (tuple): Mask size as (height, width).
        boxes (list): Tuples (x1, y1, x2, y2) in pixel coordinates.

    Returns:
        torch.Tensor: Float mask of shape ``img_size``.
    """
    mask = torch.zeros(img_size)
    for box in boxes:
        left, top, right, bottom = box
        mask[top:bottom, left:right] = 1
    return mask
103
+
def centroid_to_hex(centroid):
    """Pack a sequence of 0/1 bit tensors into an 8-digit lowercase hex string."""
    bits = "".join(str(int(bit.item())) for bit in centroid)
    value = int(bits, 2) if bits else 0
    return format(value, '08x')
109
+
# Instantiate the watermark-anything model once at import time; shared by the
# embed and detect handlers below.
wam = load_wam()
112
+
def detect_watermark(image):
    """Gradio handler for the "Detect Watermark" tab.

    Args:
        image: Uploaded image as a numpy array, or None when nothing uploaded.

    Returns:
        Exactly three values — (predicted mask image, cluster visualization,
        message JSON) — matching the three output components wired to the
        detect button's click handler.
    """
    if image is None:
        # Bug fix: the original returned FOUR values here while the click
        # handler declares only three outputs, which makes Gradio error out.
        return None, None, {"status": "error", "messages": [], "error": "No image provided"}

    img_pil = Image.fromarray(image).convert("RGB")
    det_img, pred, positions, centroids, centroids_pt = image_detect(img_pil)

    # Convert tensor images to numpy for display
    detected_img = torch_to_np(det_img.detach())
    pred_mask = torch_to_np(pred.detach().repeat(1, 3, 1, 1))

    # Create cluster visualization
    if positions is not None:
        resize_ori = transforms.Resize(det_img.shape[-2:])
        # NOTE(review): dims here are (last, second-to-last) — looks transposed
        # for non-square maps; confirm positions' layout before changing.
        rgb_image = torch.zeros((3, positions.shape[-1], positions.shape[-2]), dtype=torch.uint8)
        for value, color in color_map.items():
            mask_ = positions == value
            for channel, color_value in enumerate(color):
                rgb_image[channel][mask_.squeeze()] = color_value
        rgb_image = resize_ori(rgb_image.float()/255)
        cluster_viz = rgb_image.permute(1, 2, 0).numpy()

        # Create message output as JSON, one entry per detected cluster.
        messages = []
        for key in centroids.keys():
            centroid_hex = centroid_to_hex(centroids[key])
            # Reformat the 8 hex digits as xxxx-xxxx for readability.
            centroid_hex_array = "-".join([centroid_hex[i:i+4] for i in range(0, len(centroid_hex), 4)])
            messages.append({
                "id": int(key),
                "message": centroid_hex_array,
                "color": color_map[key]
            })
        message_json = {
            "status": "success",
            "messages": messages,
            "count": len(messages)
        }
    else:
        cluster_viz = np.zeros_like(detected_img)
        message_json = {
            "status": "no_detection",
            "messages": [],
            "count": 0
        }

    return pred_mask, cluster_viz, message_json
159
+
def embed_watermark(image, wm_num, wm_type, wm_str, wm_loc):
    """Gradio handler for the "Embed Watermark" tab.

    Args:
        image: Uploaded image as a numpy array, or None.
        wm_num: Number of watermarks to embed (1-5).
        wm_type: "random" (generate messages) or "input" (parse wm_str).
        wm_str: Hex message string in xxxx-xxxx-... groups (for "input").
        wm_loc: "random" (random masks) or "bounding" (user-drawn ROIs).

    Returns:
        (watermarked image, merged mask image, status JSON); the image slots
        are None on validation failure.
    """
    if image is None:
        return None, None, {
            "status": "failure",
            "messages": "No image provided"
        }

    # Validate user-supplied messages: wm_num groups of xxxx-xxxx uppercase hex.
    if wm_type == "input":
        if not re.match(r"^([0-9A-F]{4}-[0-9A-F]{4}-){%d}[0-9A-F]{4}-[0-9A-F]{4}$" % (wm_num-1), wm_str):
            tip = "-".join([f"FFFF-{_}{_}{_}{_}" for _ in range(wm_num)])
            return None, None, {
                "status": "failure",
                "messages": f"Invalid type input. Please use {tip}"
            }

    # Bounding mode needs two recorded clicks (one box) per watermark.
    if wm_loc == "bounding":
        if ROI_coordinates['clicks'] != wm_num * 2:
            return None, None, {
                "status": "failure",
                "messages": "Invalid location input. Please draw at least %d bounding ROI" % (wm_num)
            }

    img_pil = Image.fromarray(image).convert("RGB")

    # Generate watermark messages based on type
    wm_msgs = []
    if wm_type == "random":
        # Two 4-hex-digit groups per watermark => 32 bits each.
        chars = '-'.join(''.join(random.choice(string.hexdigits) for _ in range(4)) for _ in range(wm_num * 2))
        wm_str = chars.lower()
    # Parse the hex string into per-watermark 32-bit bit lists.
    wm_hex = wm_str.replace("-", "")
    for i in range(0, len(wm_hex), 8):
        chunk = wm_hex[i:i+8]
        binary = bin(int(chunk, 16))[2:].zfill(32)
        wm_msgs.append([int(b) for b in binary])
    # Define a 32-bit message to be embedded into the images
    wm_msgs = torch.tensor(wm_msgs, dtype=torch.float32).to(device)

    # Create mask based on location type
    wm_masks = None
    if wm_loc == "random":
        img_pt = default_transform(img_pil).unsqueeze(0).to(device)
        # To ensure at least `proportion_masked %` of the width is randomly usable,
        # otherwise, it is easy to enter an infinite loop and fail to find a usable width.
        mask_percentage = img_pil.height / img_pil.width * proportion_masked / wm_num
        wm_masks = create_random_mask(img_pt, num_masks=wm_num, mask_percentage=mask_percentage)
    elif wm_loc == "bounding" and sections:
        # One mask per drawn ROI; normalize corner order with min/max since the
        # user may have clicked the corners in any order.
        wm_masks = torch.zeros((len(sections), 1, img_pil.height, img_pil.width), dtype=torch.float32).to(device)
        for idx, ((x_start, y_start, x_end, y_end), _) in enumerate(sections):
            left = min(x_start, x_end)
            right = max(x_start, x_end)
            top = min(y_start, y_end)
            bottom = max(y_start, y_end)
            wm_masks[idx, 0, top:bottom, left:right] = 1

    img_pt, embed_img_pt, embed_mask_pt = image_embed(img_pil, wm_msgs, wm_masks)

    # Convert to numpy for display
    img_np = torch_to_np(embed_img_pt.detach())
    mask_np = torch_to_np(embed_mask_pt.detach().expand(3, -1, -1))
    message_json = {
        "status": "success",
        "messages": wm_str
    }
    return img_np, mask_np, message_json
225
+
226
+
227
+
# ROI means Region Of Interest. It is the region where the user clicks
# to specify the location of the watermark.
# Module-level click state shared by get_select_coordinates /
# del_select_coordinates / embed_watermark:
#   x_temp/y_temp — previous click position
#   x_new/y_new   — most recent click position
#   clicks        — total clicks recorded (2 clicks define one box)
ROI_coordinates = {
    'x_temp': 0,
    'y_temp': 0,
    'x_new': 0,
    'y_new': 0,
    'clicks': 0,
}

# Annotations for gr.AnnotatedImage: list of ((x1, y1, x2, y2), label) pairs,
# one per (possibly in-progress) bounding box.
sections = []
239
+
def get_select_coordinates(img, evt: gr.SelectData, num):
    """Record one corner click of a bounding ROI on the embed image.

    Odd clicks append a small placeholder box; even clicks replace the
    placeholder with the completed box. Returns the (image, annotations)
    pair expected by gr.AnnotatedImage.
    """
    if ROI_coordinates['clicks'] >= num * 2:
        gr.Warning(f"Cant add more than {num} of Watermarks.")
        return (img, sections)

    # Shift the previous click into *_temp, then store the new click.
    ROI_coordinates['clicks'] += 1
    ROI_coordinates['x_temp'], ROI_coordinates['y_temp'] = ROI_coordinates['x_new'], ROI_coordinates['y_new']
    ROI_coordinates['x_new'], ROI_coordinates['y_new'] = evt.index[0], evt.index[1]

    # Normalize the two clicks into top-left / bottom-right corners.
    x_start = min(ROI_coordinates['x_new'], ROI_coordinates['x_temp'])
    y_start = min(ROI_coordinates['y_new'], ROI_coordinates['y_temp'])
    x_end = max(ROI_coordinates['x_new'], ROI_coordinates['x_temp'])
    y_end = max(ROI_coordinates['y_new'], ROI_coordinates['y_temp'])

    if ROI_coordinates['clicks'] % 2 == 0:
        # Second click of the pair: replace the placeholder with the full box.
        sections[-1] = ((x_start, y_start, x_end, y_end), f"Mask {len(sections)}")
    else:
        # First click: show a placeholder box sized ~5% of the image height.
        point_width = int(img.shape[0] * 0.05)
        sections.append(((ROI_coordinates['x_new'], ROI_coordinates['y_new'],
                          ROI_coordinates['x_new'] + point_width,
                          ROI_coordinates['y_new'] + point_width),
                         f"Click second point for Mask {len(sections) + 1}"))
    return (img, sections)
266
+
def del_select_coordinates(img, evt: gr.SelectData):
    """Delete the clicked annotation box and repair the shared click state.

    `evt.index` is the index of the annotation the user clicked in the
    gr.AnnotatedImage; `sections` and ROI_coordinates['clicks'] are module
    globals kept in sync with the displayed boxes.
    """
    del sections[evt.index]
    # recreate section names
    for i in range(len(sections)):
        sections[i] = (sections[i][0], f"Mask {i + 1}")

    # last section clicking second point not complete
    if ROI_coordinates['clicks'] % 2 != 0:
        if len(sections) == evt.index:
            # delete last section: only its first click was recorded
            ROI_coordinates['clicks'] -= 1
        else:
            # a completed box was removed (2 clicks); relabel the still-pending
            # placeholder so it again asks for its second corner
            ROI_coordinates['clicks'] -= 2
            sections[len(sections) - 1] = (sections[len(sections) - 1][0], f"Click second point for Mask {len(sections) + 1}")
    else:
        # every box was complete; removing one frees two clicks
        ROI_coordinates['clicks'] -= 2

    # img is the AnnotatedImage value — presumably an (image, annotations)
    # pair, hence img[0]; TODO confirm against Gradio's AnnotatedImage docs.
    return (img[0], sections)
286
+
# Build the two-tab Gradio UI and launch it at import time.
with gr.Blocks(title="Watermark Anything Demo") as demo:
    gr.Markdown("""
    # Watermark Anything Demo
    This app demonstrates watermark detection and embedding using the Watermark Anything model.
    Find the project [here](https://github.com/facebookresearch/watermark-anything).
    """)

    with gr.Tabs():
        with gr.TabItem("Embed Watermark"):
            with gr.Row():
                with gr.Column():
                    embedding_img = gr.Image(label="Input Image", type="numpy")

                with gr.Column():
                    # Controls: how many watermarks, message source, placement.
                    embedding_num = gr.Slider(1, 5, value=1, step=1, label="Number of Watermarks")
                    embedding_type = gr.Radio(["random", "input"], value="random", label="Type", info="Type of watermarks")
                    embedding_str = gr.Textbox(label="Watermark Text", visible=False, show_copy_button=True)
                    embedding_loc = gr.Radio(["random", "bounding"], value="random", label="Location", info="Location of watermarks")

                    # Re-render this sub-UI whenever the location choice changes:
                    # "bounding" shows an AnnotatedImage for drawing ROIs.
                    @gr.render(inputs=embedding_loc)
                    def show_split(wm_loc):
                        if wm_loc == "bounding":
                            embedding_box = gr.AnnotatedImage(
                                label="ROI",
                                color_map={
                                    "ROI of Watermark embedding": "#9987FF",
                                    "Click second point for ROI": "#f44336"}
                            )

                            # Clicks on the input image add ROI corners; clicks
                            # on an annotation delete that box.
                            embedding_img.select(
                                fn=get_select_coordinates,
                                inputs=[embedding_img, embedding_num],
                                outputs=embedding_box)
                            embedding_box.select(
                                fn=del_select_coordinates,
                                inputs=embedding_box,
                                outputs=embedding_box
                            )
                        else:
                            # "random": detach the click handler.
                            embedding_img.select()

            embedding_btn = gr.Button("Embed Watermark")
            marked_msg = gr.JSON(label="Marked Messages")
            with gr.Row():
                marked_image = gr.Image(label="Watermarked Image")
                marked_mask = gr.Image(label="Position of the watermark")

            def visible_text_label(embedding_type, embedding_num):
                # Show the text box (with a format hint) only for "input" type.
                if embedding_type == "input":
                    tip = "-".join([f"FFFF-{_}{_}{_}{_}" for _ in range(embedding_num)])
                    return gr.update(visible=True, label=f"Watermark Text (Format: {tip})")
                else:
                    return gr.update(visible=False)

            def check_embedding_str(embedding_str, embedding_num):
                # Disable the embed button while the typed message is malformed.
                if not re.match(r"^([0-9A-F]{4}-[0-9A-F]{4}-){%d}[0-9A-F]{4}-[0-9A-F]{4}$" % (embedding_num-1), embedding_str):
                    tip = "-".join([f"FFFF-{_}{_}{_}{_}" for _ in range(embedding_num)])
                    gr.Warning(f"Invalid format. Please use {tip}", duration=0)
                    return gr.update(interactive=False)
                else:
                    return gr.update(interactive=True)

            # Keep the text box visibility/label in sync with count and type.
            embedding_num.change(
                fn=visible_text_label,
                inputs=[embedding_type, embedding_num],
                outputs=[embedding_str]
            )
            embedding_type.change(
                fn=visible_text_label,
                inputs=[embedding_type, embedding_num],
                outputs=[embedding_str]
            )
            embedding_str.change(
                fn=check_embedding_str,
                inputs=[embedding_str, embedding_num],
                outputs=[embedding_btn]
            )

            embedding_btn.click(
                fn=embed_watermark,
                inputs=[embedding_img, embedding_num, embedding_type, embedding_str, embedding_loc],
                outputs=[marked_image, marked_mask, marked_msg]
            )

        with gr.TabItem("Detect Watermark"):
            with gr.Row():
                with gr.Column():
                    detecting_img = gr.Image(label="Input Image", type="numpy")
                with gr.Column():
                    detecting_btn = gr.Button("Detect Watermark")
                    predicted_messages = gr.JSON(label="Detected Messages")
            with gr.Row():
                predicted_mask = gr.Image(label="Predicted Watermark Position")
                predicted_cluster = gr.Image(label="Watermark Clusters")

            detecting_btn.click(
                fn=detect_watermark,
                inputs=[detecting_img],
                outputs=[predicted_mask, predicted_cluster, predicted_messages]
            )

demo.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch==2.5.1
2
+ GitPython==3.1.43
3
+ gradio==5.8.0
4
+ huggingface-hub==0.26.3