|
import gradio as gr |
|
import librosa |
|
import numpy as np |
|
import torch |
|
import torchaudio |
|
import torchaudio.transforms as T |
|
import matplotlib.pyplot as plt |
|
from PIL import Image |
|
import io |
|
import tempfile |
|
import logging |
|
from audioseal import AudioSeal |
|
import random |
|
from pathlib import Path |
|
|
|
# Route every record (DEBUG and above) to app.log. File mode 'w' truncates
# the file on each start, so the log only ever reflects the current run.
logging.basicConfig(
    format='%(name)s - %(levelname)s - %(message)s',
    filemode='w',
    filename='app.log',
    level=logging.DEBUG,
)

# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
def generate_random_binary_message(length=16):
    """Return a random string of ``length`` characters, each '0' or '1'.

    Every bit is drawn with its own ``torch.randint`` call, so the result
    follows torch's global random state.
    """
    bits = []
    for _ in range(length):
        bit = torch.randint(0, 2, (1,)).item()
        bits.append(str(bit))
    return ''.join(bits)
|
|
|
def load_and_resample_audio(audio_file_path, target_sample_rate=16000):
    """Load an audio file and bring it to ``target_sample_rate``.

    Args:
        audio_file_path: Path handed to ``torchaudio.load``.
        target_sample_rate: Sample rate the returned waveform should have.

    Returns:
        A ``(waveform, target_sample_rate)`` tuple. The waveform is passed
        through ``T.Resample`` only when the file's own rate differs from
        the target.
    """
    waveform, native_rate = torchaudio.load(audio_file_path)

    # Nothing to do when the file is already at the requested rate.
    if native_rate == target_sample_rate:
        return waveform, target_sample_rate

    to_target = T.Resample(orig_freq=native_rate, new_freq=target_sample_rate)
    return to_target(waveform), target_sample_rate
|
|
|
def plot_spectrogram_to_image(waveform, sample_rate, n_fft=400):
    """Render the power spectrogram (in dB) of a waveform as a PIL image.

    ``imshow`` only accepts 2-D (or RGB/RGBA) arrays, so a waveform that
    carries leading batch/channel dimensions is reduced to its first channel
    before the transform. Previously such input produced a 3-D or 4-D
    spectrogram and ``imshow`` raised ``TypeError: Invalid shape``.

    Args:
        waveform: Tensor whose last dimension is time; any leading
            dimensions are collapsed and only the first row is plotted.
        sample_rate: Unused; kept so the signature stays unchanged.
        n_fft: FFT size passed to ``T.Spectrogram``.

    Returns:
        A ``PIL.Image.Image`` decoded from the in-memory PNG of the plot.
    """
    # Collapse every leading dimension and keep the first row -> 1-D signal,
    # which yields a 2-D (freq, frames) spectrogram.
    signal = waveform.detach().reshape(-1, waveform.shape[-1])[0]

    spectrogram = T.Spectrogram(n_fft=n_fft, power=2)(signal)
    spectrogram_db = T.AmplitudeToDB()(spectrogram)

    # Use an explicit figure instead of pyplot's implicit "current figure"
    # and always close it, even if rendering fails.
    fig, ax = plt.subplots(figsize=(10, 4))
    try:
        ax.imshow(spectrogram_db.cpu().numpy(), cmap='hot', aspect='auto')
        ax.axis('off')
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
|
|
|
def plot_waveform_to_image(waveform, sample_rate):
    """Render a waveform as a black line plot and return it as a PIL image.

    Only the first channel is drawn. The input is reduced to a 1-D signal
    regardless of how many leading batch/channel dimensions it has:
    indexing ``[0]`` on a 3-D ``(batch, channels, time)`` tensor used to
    hand ``plot`` a 2-D array (drawn as one line per sample), and on a 1-D
    tensor it produced a single scalar.

    Args:
        waveform: Tensor whose last dimension is time.
        sample_rate: Unused; kept so the signature stays unchanged.

    Returns:
        A ``PIL.Image.Image`` decoded from the in-memory PNG of the plot.
    """
    # Collapse every leading dimension and keep the first row -> 1-D signal.
    signal = waveform.detach().reshape(-1, waveform.shape[-1])[0]

    # Use an explicit figure instead of pyplot's implicit "current figure"
    # and always close it, even if rendering fails.
    fig, ax = plt.subplots(figsize=(10, 4))
    try:
        ax.plot(signal.cpu().numpy(), color='black')
        ax.axis('off')
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
|
|
|
def watermark_audio(audio_file_path):
    """Embed a random 16-bit AudioSeal watermark into an audio file.

    The audio is loaded at 16000 Hz, clamped to [-1, 1], watermarked with
    the "audioseal_wm_16bits" generator and written to a temporary WAV file.

    Args:
        audio_file_path: Path of the audio file to watermark.

    Returns:
        A 6-tuple: path of the watermarked WAV file, the binary message
        string that was embedded, then four PIL images (original waveform,
        original spectrogram, watermarked waveform, watermarked spectrogram).
        The temporary file is not deleted here; the caller owns it.
    """
    waveform, sample_rate = load_and_resample_audio(audio_file_path, 16000)
    waveform = torch.clamp(waveform, min=-1.0, max=1.0)

    # The generator is called with a (batch, channels, time) tensor, so pad
    # missing leading dimensions.
    # NOTE(review): multi-channel input is passed through unchanged; confirm
    # the generator accepts more than one channel.
    while waveform.ndim < 3:
        waveform = waveform.unsqueeze(0)

    # The plot helpers expect un-batched input: (channels, time) for the
    # waveform plot and a 1-D signal for the spectrogram (imshow needs a
    # 2-D array). Passing the batched tensor broke both plots.
    original_channels = waveform.squeeze(0)
    original_waveform_image = plot_waveform_to_image(original_channels, sample_rate)
    original_spec_image = plot_spectrogram_to_image(original_channels[0], sample_rate)

    generator = AudioSeal.load_generator("audioseal_wm_16bits")
    message = generate_random_binary_message()
    message_tensor = torch.tensor(
        [int(bit) for bit in message], dtype=torch.int32
    ).unsqueeze(0)

    # Inference only: skip autograd bookkeeping so the output can be saved
    # and converted to numpy without carrying a graph.
    with torch.no_grad():
        watermarked_audio = generator(waveform, message=message_tensor)
    watermarked_channels = watermarked_audio.squeeze(0)

    # Reserve a unique path, then close the handle before writing so the
    # descriptor is not leaked and the file is not written while open.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
        output_path = temp_file.name
    torchaudio.save(output_path, watermarked_channels, sample_rate)

    watermarked_waveform_image = plot_waveform_to_image(watermarked_channels, sample_rate)
    watermarked_spec_image = plot_spectrogram_to_image(watermarked_channels[0], sample_rate)

    return (
        output_path,
        message,
        original_waveform_image,
        original_spec_image,
        watermarked_waveform_image,
        watermarked_spec_image,
    )
|
|
|
def detect_watermark(audio_file_path, sample_rate=16000):
    """Run the AudioSeal detector on an audio file and report the verdict.

    Args:
        audio_file_path: Path of the audio file to inspect.
        sample_rate: Rate the audio is resampled to before detection.

    Returns:
        A ``(message, spectrogram_image)`` tuple: a human-readable detection
        result string and a PIL image of the first channel's spectrogram.
    """
    waveform, sample_rate = load_and_resample_audio(audio_file_path, sample_rate)

    detector = AudioSeal.load_detector("audioseal_detector_16bits")

    # Inference only: no autograd graph is needed. The decoded message bits
    # returned alongside the scores are not used here.
    with torch.no_grad():
        results, _messages = detector.forward(
            waveform.unsqueeze(0), sample_rate=sample_rate
        )

    # Index 1 of the second dimension is averaged into one score.
    # NOTE(review): presumably the per-frame "watermarked" probability, per
    # AudioSeal's documentation -- confirm.
    detect_probs = results[:, 1, :]
    result = detect_probs.mean().cpu().item()
    message = f"Detection result: {'Watermarked Audio' if result > 0.5 else 'Not watermarked'}"

    # Plot the first channel only: a (channels, time) tensor would produce a
    # 3-D spectrogram, which imshow rejects.
    spectrogram_image = plot_spectrogram_to_image(waveform[0], sample_rate)
    return message, spectrogram_image
|
|
|
# Custom stylesheet for the Gradio UI, read once at import time.
style_path = Path("style.css")
# Decode explicitly as UTF-8 so the result does not depend on the host's
# locale encoding.
style = style_path.read_text(encoding="utf-8")
|
|
|
# Two-tab UI: one tab embeds a watermark, the other detects one.
with gr.Blocks(css=style) as demo:
    with gr.Tab("Watermark Audio"):
        with gr.Column(scale=6):
            audio_input_watermark = gr.Audio(label="Upload Audio File for Watermarking", type="filepath")
            watermark_button = gr.Button("Apply Watermark")
            watermarked_audio_output = gr.Audio(label="Watermarked Audio")
            binary_message_output = gr.Textbox(label="Binary Message")
            original_waveform_output = gr.Image(label="Original Waveform")
            original_spectrogram_output = gr.Image(label="Original Spectrogram")
            watermarked_waveform_output = gr.Image(label="Watermarked Waveform")
            watermarked_spectrogram_output = gr.Image(label="Watermarked Spectrogram")
            # Output order must match the tuple returned by watermark_audio.
            watermark_button.click(
                fn=watermark_audio,
                inputs=audio_input_watermark,
                outputs=[
                    watermarked_audio_output,
                    binary_message_output,
                    original_waveform_output,
                    original_spectrogram_output,
                    watermarked_waveform_output,
                    watermarked_spectrogram_output,
                ],
            )

    with gr.Tab("Detect Watermark"):
        with gr.Column(scale=6):
            audio_input_detect_watermark = gr.Audio(label="Upload Audio File for Watermark Detection", type="filepath")
            detect_watermark_button = gr.Button("Detect Watermark")
            watermark_detection_output = gr.Textbox(label="Watermark Detection Result")
            spectrogram_image_output = gr.Image(label="Spectrogram")
            # Event inputs must be Gradio components; the literal "16000"
            # that used to be listed here is not one. detect_watermark
            # already defaults sample_rate to 16000, so only the audio
            # component is wired in.
            detect_watermark_button.click(
                fn=detect_watermark,
                inputs=audio_input_detect_watermark,
                outputs=[watermark_detection_output, spectrogram_image_output],
            )
|
|
|
# Start the server only when run as a script, so importing this module
# (e.g. for reuse or testing) does not launch a blocking web server.
if __name__ == "__main__":
    demo.launch()
|
|