xiaoyao9184 commited on
Commit
077e8af
1 Parent(s): 5a2f8f1

Synced repo using 'sync_with_huggingface' Github Action

Browse files
Files changed (3) hide show
  1. gradio_app.py +240 -0
  2. gradio_run.py +7 -0
  3. requirements.txt +6 -0
gradio_app.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ if "APP_PATH" in os.environ:
5
+ app_path = os.path.abspath(os.environ["APP_PATH"])
6
+ if os.getcwd() != app_path:
7
+ # fix sys.path for import
8
+ os.chdir(app_path)
9
+ if app_path not in sys.path:
10
+ sys.path.append(app_path)
11
+
12
+ import gradio as gr
13
+
14
+ import torch
15
+ import torchaudio
16
+ import numpy as np
17
+ import matplotlib.pyplot as plt
18
+ import re
19
+ import random
20
+ import string
21
+ from audioseal import AudioSeal
22
+
23
+
24
+ # Load generator if not already loaded in reload mode
25
+ if 'generator' not in globals():
26
+ generator = AudioSeal.load_generator("audioseal_wm_16bits")
27
+
28
+ # Load detector if not already loaded in reload mode
29
+ if 'detector' not in globals():
30
+ detector = AudioSeal.load_detector("audioseal_detector_16bits")
31
+
32
+
33
+ def load_audio(file):
34
+ wav, sample_rate = torchaudio.load(file)
35
+ return wav, sample_rate
36
+
37
+ def generate_msg_pt_by_format_string(format_string, bytes_count):
38
+ msg_hex = format_string.replace("-", "")
39
+ hex_length = bytes_count * 2
40
+ binary_list = []
41
+ for i in range(0, len(msg_hex), hex_length):
42
+ chunk = msg_hex[i:i+hex_length]
43
+ binary = bin(int(chunk, 16))[2:].zfill(bytes_count * 8)
44
+ binary_list.append([int(b) for b in binary])
45
+ # torch.randint(0, 2, (1, 16), dtype=torch.int32)
46
+ msg_pt = torch.tensor(binary_list, dtype=torch.int32)
47
+ return msg_pt
48
+
49
+ def embed_watermark(audio, sr, msg):
50
+ # We add the batch dimension to the single audio to mimic the batch watermarking
51
+ original_audio = audio.unsqueeze(0)
52
+
53
+ watermark = generator.get_watermark(original_audio, sr, message=msg)
54
+
55
+ watermarked_audio = original_audio + watermark
56
+
57
+ # Alternatively, you can also call forward() function directly with different tune-down / tune-up rate
58
+ # watermarked_audio = generator(audios, sample_rate=sr, alpha=1)
59
+
60
+ return watermarked_audio
61
+
62
+ def generate_format_string_by_msg_pt(msg_pt, bytes_count):
63
+ hex_length = bytes_count * 2
64
+ binary_int = 0
65
+ for bit in msg_pt:
66
+ binary_int = (binary_int << 1) | int(bit.item())
67
+ hex_string = format(binary_int, f'0{hex_length}x')
68
+
69
+ split_hex = [hex_string[i:i + 4] for i in range(0, len(hex_string), 4)]
70
+ format_hex = "-".join(split_hex)
71
+
72
+ return hex_string, format_hex
73
+
74
+ def detect_watermark(audio, sr):
75
+ # We add the batch dimension to the single audio to mimic the batch watermarking
76
+ watermarked_audio = audio.unsqueeze(0)
77
+
78
+ result, message = detector.detect_watermark(watermarked_audio, sr)
79
+
80
+ # pred_prob is a tensor of size batch x 2 x frames, indicating the probability (positive and negative) of watermarking for each frame
81
+ # A watermarked audio should have pred_prob[:, 1, :] > 0.5
82
+ # message_prob is a tensor of size batch x 16, indicating of the probability of each bit to be 1.
83
+ # message will be a random tensor if the detector detects no watermarking from the audio
84
+ pred_prob, message_prob = detector(watermarked_audio, sr)
85
+
86
+ return result, message, pred_prob, message_prob
87
+
88
+ def get_waveform_and_specgram(batch_waveform, sample_rate):
89
+ waveform = batch_waveform.squeeze().detach().cpu().numpy()
90
+
91
+ num_frames = waveform.shape[-1]
92
+ time_axis = torch.arange(0, num_frames) / sample_rate
93
+
94
+ figure, (ax1, ax2) = plt.subplots(2, 1)
95
+
96
+ ax1.plot(time_axis, waveform, linewidth=1)
97
+ ax1.grid(True)
98
+ ax2.specgram(waveform, Fs=sample_rate)
99
+
100
+ figure.suptitle(f"Waveform and specgram")
101
+
102
+ return figure
103
+
104
+ def generate_hex_format_regex(bytes_count):
105
+ hex_length = bytes_count * 2
106
+ hex_string = 'F' * hex_length
107
+ split_hex = [hex_string[i:i + 4] for i in range(0, len(hex_string), 4)]
108
+ format_like = "-".join(split_hex)
109
+ regex_pattern = '^' + '-'.join([r'[0-9A-Fa-f]{4}'] * len(split_hex)) + '$'
110
+ return format_like, regex_pattern
111
+
112
+ def generate_hex_random_message(bytes_count):
113
+ hex_length = bytes_count * 2
114
+ hex_string = ''.join(random.choice(string.hexdigits) for _ in range(hex_length))
115
+ split_hex = [hex_string[i:i + 4] for i in range(0, len(hex_string), 4)]
116
+ random_str = "-".join(split_hex)
117
+ return random_str, "".join(split_hex)
118
+
119
+ with gr.Blocks(title="AudioSeal") as demo:
120
+ gr.Markdown("""
121
+ # AudioSeal Demo
122
+
123
+ Find the project [here](https://github.com/facebookresearch/audioseal.git).
124
+ """)
125
+
126
+ with gr.Tabs():
127
+ with gr.TabItem("Embed Watermark"):
128
+ with gr.Row():
129
+ with gr.Column():
130
+ embedding_aud = gr.Audio(label="Input Audio", type="filepath")
131
+ embedding_specgram = gr.Checkbox(label="Show specgram", value=False, info="Show debug information")
132
+
133
+ embedding_type = gr.Radio(["random", "input"], value="random", label="Type", info="Type of watermarks")
134
+
135
+ nbytes = int(generator.msg_processor.nbits / 8)
136
+ format_like, regex_pattern = generate_hex_format_regex(nbytes)
137
+ msg, _ = generate_hex_random_message(nbytes)
138
+ embedding_msg = gr.Textbox(
139
+ label=f"Message ({nbytes} bytes hex string)",
140
+ info=f"format like {format_like}",
141
+ value=msg,
142
+ interactive=False, show_copy_button=True)
143
+
144
+ embedding_btn = gr.Button("Embed Watermark")
145
+ with gr.Column():
146
+ marked_aud = gr.Audio(label="Output Audio", show_download_button=True)
147
+ specgram_original = gr.Plot(label="Original Audio", format="png", visible=False)
148
+ specgram_watermarked = gr.Plot(label="Watermarked Audio", format="png", visible=False)
149
+
150
+
151
+ def change_embedding_type(type):
152
+ if type == "random":
153
+ msg, _ = generate_hex_random_message(nbytes)
154
+ return gr.update(interactive=False, value=msg)
155
+ else:
156
+ return gr.update(interactive=True)
157
+ embedding_type.change(
158
+ fn=change_embedding_type,
159
+ inputs=[embedding_type],
160
+ outputs=[embedding_msg]
161
+ )
162
+
163
+ def check_embedding_msg(msg):
164
+ if not re.match(regex_pattern, msg):
165
+ gr.Warning(
166
+ f"Invalid format. Please use like '{format_like}'",
167
+ duration=0)
168
+ embedding_msg.change(
169
+ fn=check_embedding_msg,
170
+ inputs=[embedding_msg],
171
+ outputs=[]
172
+ )
173
+
174
+ def run_embed_watermark(file, show_specgram, type, msg):
175
+ if file is None:
176
+ raise gr.Erro("No file uploaded", duration=5)
177
+ if not re.match(regex_pattern, msg):
178
+ raise gr.Error(f"Invalid format. Please use like '{format_like}'", duration=5)
179
+
180
+ audio_original, rate = load_audio(file)
181
+ msg_pt = generate_msg_pt_by_format_string(msg, nbytes)
182
+ audio_watermarked = embed_watermark(audio_original, rate, msg_pt)
183
+ output = rate, audio_watermarked.squeeze().detach().cpu().numpy().astype(np.float32)
184
+
185
+ if show_specgram:
186
+ fig_original = get_waveform_and_specgram(audio_original, rate)
187
+ fig_watermarked = get_waveform_and_specgram(audio_watermarked, rate)
188
+ return [
189
+ output,
190
+ gr.update(visible=True, value=fig_original),
191
+ gr.update(visible=True, value=fig_watermarked)]
192
+ else:
193
+ return [
194
+ output,
195
+ gr.update(visible=False),
196
+ gr.update(visible=False)]
197
+
198
+ embedding_btn.click(
199
+ fn=run_embed_watermark,
200
+ inputs=[embedding_aud, embedding_specgram, embedding_type, embedding_msg],
201
+ outputs=[marked_aud, specgram_original, specgram_watermarked]
202
+ )
203
+
204
+ with gr.TabItem("Detect Watermark"):
205
+ with gr.Row():
206
+ with gr.Column():
207
+ detecting_aud = gr.Audio(label="Input Audio", type="filepath")
208
+ with gr.Column():
209
+ detecting_btn = gr.Button("Detect Watermark")
210
+ predicted_messages = gr.JSON(label="Detected Messages")
211
+
212
+ def run_detect_watermark(file):
213
+ if file is None:
214
+ raise gr.Error("No file uploaded", duration=5)
215
+
216
+ audio_watermarked, rate = load_audio(file)
217
+ result, message, pred_prob, message_prob = detect_watermark(audio_watermarked, rate)
218
+
219
+ _, fromat_msg = generate_format_string_by_msg_pt(message[0], nbytes)
220
+
221
+ sum_above_05 = (pred_prob[:, 1, :] > 0.5).sum(dim=1)
222
+
223
+ # Create message output as JSON
224
+ message_json = {
225
+ "socre": result,
226
+ "message": fromat_msg,
227
+ "frames_count_all": pred_prob.shape[2],
228
+ "frames_count_above_05": sum_above_05[0].item(),
229
+ "bits_probability": message_prob[0].tolist(),
230
+ "bits_massage": message[0].tolist()
231
+ }
232
+ return message_json
233
+ detecting_btn.click(
234
+ fn=run_detect_watermark,
235
+ inputs=[detecting_aud],
236
+ outputs=[predicted_messages]
237
+ )
238
+
239
+ if __name__ == "__main__":
240
+ demo.launch()
gradio_run.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # NOTE: copy from gradio bin
2
+ import re
3
+ import sys
4
+ from gradio.cli import cli
5
+ if __name__ == '__main__':
6
+ sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
7
+ sys.exit(cli())
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch==2.5.1
2
+ gradio==5.8.0
3
+ huggingface-hub==0.26.3
4
+ audioseal==0.1.4
5
+ matplotlib==3.10.0
6
+ soundfile==0.12.1