xiaoyao9184 commited on
Commit
cb6069f
1 Parent(s): 4867423

Synced repo using 'sync_with_huggingface' Github Action

Browse files
Files changed (1) hide show
  1. gradio_app.py +30 -15
gradio_app.py CHANGED
@@ -20,14 +20,18 @@ import random
20
  import string
21
  from audioseal import AudioSeal
22
 
 
23
 
24
  # Load generator if not already loaded in reload mode
25
  if 'generator' not in globals():
26
  generator = AudioSeal.load_generator("audioseal_wm_16bits")
 
 
27
 
28
  # Load detector if not already loaded in reload mode
29
  if 'detector' not in globals():
30
  detector = AudioSeal.load_detector("audioseal_detector_16bits")
 
31
 
32
 
33
  def load_audio(file):
@@ -44,11 +48,15 @@ def generate_msg_pt_by_format_string(format_string, bytes_count):
44
  binary_list.append([int(b) for b in binary])
45
  # torch.randint(0, 2, (1, 16), dtype=torch.int32)
46
  msg_pt = torch.tensor(binary_list, dtype=torch.int32)
47
- return msg_pt
48
 
49
  def embed_watermark(audio, sr, msg):
50
  # We add the batch dimension to the single audio to mimic the batch watermarking
51
- original_audio = audio.unsqueeze(0)
 
 
 
 
52
 
53
  watermark = generator.get_watermark(original_audio, sr, message=msg)
54
 
@@ -73,7 +81,11 @@ def generate_format_string_by_msg_pt(msg_pt, bytes_count):
73
 
74
  def detect_watermark(audio, sr):
75
  # We add the batch dimension to the single audio to mimic the batch watermarking
76
- watermarked_audio = audio.unsqueeze(0)
 
 
 
 
77
 
78
  result, message = detector.detect_watermark(watermarked_audio, sr)
79
 
@@ -85,8 +97,12 @@ def detect_watermark(audio, sr):
85
 
86
  return result, message, pred_prob, message_prob
87
 
88
- def get_waveform_and_specgram(batch_waveform, sample_rate):
89
- waveform = batch_waveform.squeeze().detach().cpu().numpy()
 
 
 
 
90
 
91
  num_frames = waveform.shape[-1]
92
  time_axis = torch.arange(0, num_frames) / sample_rate
@@ -132,11 +148,10 @@ with gr.Blocks(title="AudioSeal") as demo:
132
 
133
  embedding_type = gr.Radio(["random", "input"], value="random", label="Type", info="Type of watermarks")
134
 
135
- nbytes = int(generator.msg_processor.nbits / 8)
136
- format_like, regex_pattern = generate_hex_format_regex(nbytes)
137
- msg, _ = generate_hex_random_message(nbytes)
138
  embedding_msg = gr.Textbox(
139
- label=f"Message ({nbytes} bytes hex string)",
140
  info=f"format like {format_like}",
141
  value=msg,
142
  interactive=False, show_copy_button=True)
@@ -150,7 +165,7 @@ with gr.Blocks(title="AudioSeal") as demo:
150
 
151
  def change_embedding_type(type):
152
  if type == "random":
153
- msg, _ = generate_hex_random_message(nbytes)
154
  return gr.update(interactive=False, value=msg)
155
  else:
156
  return gr.update(interactive=True)
@@ -178,13 +193,13 @@ with gr.Blocks(title="AudioSeal") as demo:
178
  raise gr.Error(f"Invalid format. Please use like '{format_like}'", duration=5)
179
 
180
  audio_original, rate = load_audio(file)
181
- msg_pt = generate_msg_pt_by_format_string(msg, nbytes)
182
  audio_watermarked = embed_watermark(audio_original, rate, msg_pt)
183
  output = rate, audio_watermarked.squeeze().detach().cpu().numpy().astype(np.float32)
184
 
185
  if show_specgram:
186
- fig_original = get_waveform_and_specgram(audio_original, rate)
187
- fig_watermarked = get_waveform_and_specgram(audio_watermarked, rate)
188
  return [
189
  output,
190
  gr.update(visible=True, value=fig_original),
@@ -205,8 +220,8 @@ with gr.Blocks(title="AudioSeal") as demo:
205
  with gr.Row():
206
  with gr.Column():
207
  detecting_aud = gr.Audio(label="Input Audio", type="filepath")
208
- with gr.Column():
209
  detecting_btn = gr.Button("Detect Watermark")
 
210
  predicted_messages = gr.JSON(label="Detected Messages")
211
 
212
  def run_detect_watermark(file):
@@ -216,7 +231,7 @@ with gr.Blocks(title="AudioSeal") as demo:
216
  audio_watermarked, rate = load_audio(file)
217
  result, message, pred_prob, message_prob = detect_watermark(audio_watermarked, rate)
218
 
219
- _, fromat_msg = generate_format_string_by_msg_pt(message[0], nbytes)
220
 
221
  sum_above_05 = (pred_prob[:, 1, :] > 0.5).sum(dim=1)
222
 
 
20
  import string
21
  from audioseal import AudioSeal
22
 
23
+ device = "cuda" if torch.cuda.is_available() else "cpu"
24
 
25
  # Load generator if not already loaded in reload mode
26
  if 'generator' not in globals():
27
  generator = AudioSeal.load_generator("audioseal_wm_16bits")
28
+ generator = generator.to(device)
29
+ generator_nbytes = int(generator.msg_processor.nbits / 8)
30
 
31
  # Load detector if not already loaded in reload mode
32
  if 'detector' not in globals():
33
  detector = AudioSeal.load_detector("audioseal_detector_16bits")
34
+ detector = detector.to(device)
35
 
36
 
37
  def load_audio(file):
 
48
  binary_list.append([int(b) for b in binary])
49
  # torch.randint(0, 2, (1, 16), dtype=torch.int32)
50
  msg_pt = torch.tensor(binary_list, dtype=torch.int32)
51
+ return msg_pt.to(device)
52
 
53
  def embed_watermark(audio, sr, msg):
54
  # We add the batch dimension to the single audio to mimic the batch watermarking
55
+ original_audio = audio.unsqueeze(0).to(device)
56
+
57
+ # If the audio has more than one channel, average all channels to 1 channel
58
+ if original_audio.shape[0] > 1:
59
+ original_audio = torch.mean(original_audio, dim=0, keepdim=True)
60
 
61
  watermark = generator.get_watermark(original_audio, sr, message=msg)
62
 
 
81
 
82
  def detect_watermark(audio, sr):
83
  # We add the batch dimension to the single audio to mimic the batch watermarking
84
+ watermarked_audio = audio.unsqueeze(0).to(device)
85
+
86
+ # If the audio has more than one channel, average all channels to 1 channel
87
+ if watermarked_audio.shape[0] > 1:
88
+ watermarked_audio = torch.mean(watermarked_audio, dim=0, keepdim=True)
89
 
90
  result, message = detector.detect_watermark(watermarked_audio, sr)
91
 
 
97
 
98
  return result, message, pred_prob, message_prob
99
 
100
+ def get_waveform_and_specgram(waveform, sample_rate):
101
+ # If the audio has more than one channel, average all channels to 1 channel
102
+ if waveform.shape[0] > 1:
103
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
104
+
105
+ waveform = waveform.squeeze().detach().cpu().numpy()
106
 
107
  num_frames = waveform.shape[-1]
108
  time_axis = torch.arange(0, num_frames) / sample_rate
 
148
 
149
  embedding_type = gr.Radio(["random", "input"], value="random", label="Type", info="Type of watermarks")
150
 
151
+ format_like, regex_pattern = generate_hex_format_regex(generator_nbytes)
152
+ msg, _ = generate_hex_random_message(generator_nbytes)
 
153
  embedding_msg = gr.Textbox(
154
+ label=f"Message ({generator_nbytes} bytes hex string)",
155
  info=f"format like {format_like}",
156
  value=msg,
157
  interactive=False, show_copy_button=True)
 
165
 
166
  def change_embedding_type(type):
167
  if type == "random":
168
+ msg, _ = generate_hex_random_message(generator_nbytes)
169
  return gr.update(interactive=False, value=msg)
170
  else:
171
  return gr.update(interactive=True)
 
193
  raise gr.Error(f"Invalid format. Please use like '{format_like}'", duration=5)
194
 
195
  audio_original, rate = load_audio(file)
196
+ msg_pt = generate_msg_pt_by_format_string(msg, generator_nbytes)
197
  audio_watermarked = embed_watermark(audio_original, rate, msg_pt)
198
  output = rate, audio_watermarked.squeeze().detach().cpu().numpy().astype(np.float32)
199
 
200
  if show_specgram:
201
+ fig_original = get_waveform_and_specgram(audio_original.squeeze(), rate)
202
+ fig_watermarked = get_waveform_and_specgram(audio_watermarked.squeeze(), rate)
203
  return [
204
  output,
205
  gr.update(visible=True, value=fig_original),
 
220
  with gr.Row():
221
  with gr.Column():
222
  detecting_aud = gr.Audio(label="Input Audio", type="filepath")
 
223
  detecting_btn = gr.Button("Detect Watermark")
224
+ with gr.Column():
225
  predicted_messages = gr.JSON(label="Detected Messages")
226
 
227
  def run_detect_watermark(file):
 
231
  audio_watermarked, rate = load_audio(file)
232
  result, message, pred_prob, message_prob = detect_watermark(audio_watermarked, rate)
233
 
234
+ _, fromat_msg = generate_format_string_by_msg_pt(message[0], generator_nbytes)
235
 
236
  sum_above_05 = (pred_prob[:, 1, :] > 0.5).sum(dim=1)
237