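"""Gradio demo for audio deepfake detection.

Wraps two detectors -- a Wave2Vec2BERT-based model and a HuBERT-based model,
both fine-tuned on the WaveFake dataset -- and exposes them through a simple
upload-and-analyze web interface. Fine-tuned checkpoints are downloaded from
the TrustSafeAI/AudioDeepfakeDetectors repository on the Hugging Face Hub.
"""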
import os
from dataclasses import dataclass
from typing import Optional, Tuple

import gradio as gr
import librosa
import numpy as np
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import AutoFeatureExtractor, HubertModel, Wav2Vec2BertModel
from transformers.file_utils import ModelOutput
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2PreTrainedModel,
    Wav2Vec2Model,
)


@dataclass
class SpeechClassifierOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

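# The two model wrappers below follow the common wav2vec2-style audio-classification
# recipe: pool the encoder's last hidden states over time, then score them with a
# small dense head (dropout -> linear -> tanh -> dropout -> linear).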
class Wav2Vec2ClassificationHead(nn.Module):
    """Head for wav2vec classification task."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        x = features
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x

class Wav2Vec2ForSpeechClassification(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.num_labels = 2
        self.pooling_mode = 'mean'
        self.wav2vec2bert = Wav2Vec2BertModel.from_pretrained(model_name)
        self.config = self.wav2vec2bert.config
        self.classifier = Wav2Vec2ClassificationHead(self.wav2vec2bert.config)

    def merged_strategy(self, hidden_states, mode="mean"):
        if mode == "mean":
            outputs = torch.mean(hidden_states, dim=1)
        elif mode == "sum":
            outputs = torch.sum(hidden_states, dim=1)
        elif mode == "max":
            outputs = torch.max(hidden_states, dim=1)[0]
        else:
            raise ValueError(
                "Undefined pooling method: mode must be one of ['mean', 'sum', 'max']")
        return outputs

    def forward(
        self,
        input_features,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.wav2vec2bert(
            input_features,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs.last_hidden_state
        hidden_states = self.merged_strategy(hidden_states, mode=self.pooling_mode)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SpeechClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.last_hidden_state,
            attentions=outputs.attentions,
        )

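# NOTE: the HuBERT wrapper below keeps the attribute name `wav2vec2` for its
# backbone; presumably this is so the keys in the released checkpoint's
# state_dict line up with this module when load_state_dict() is called.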
class HuBERT(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.num_labels = 2
        self.pooling_mode = 'mean'
        self.wav2vec2 = HubertModel.from_pretrained(model_name)
        self.config = self.wav2vec2.config
        self.classifier = Wav2Vec2ClassificationHead(self.wav2vec2.config)

    def merged_strategy(self, hidden_states, mode="mean"):
        if mode == "mean":
            outputs = torch.mean(hidden_states, dim=1)
        elif mode == "sum":
            outputs = torch.sum(hidden_states, dim=1)
        elif mode == "max":
            outputs = torch.max(hidden_states, dim=1)[0]
        else:
            raise ValueError(
                "Undefined pooling method: mode must be one of ['mean', 'sum', 'max']")
        return outputs

    def forward(
        self,
        input_values,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs.last_hidden_state
        hidden_states = self.merged_strategy(hidden_states, mode=self.pooling_mode)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SpeechClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.last_hidden_state,
            attentions=outputs.attentions,
        )

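# Fix every clip to 4 s at 16 kHz (64000 samples): longer clips are randomly
# cropped, shorter clips are zero-padded at the end.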
def pad(x, max_len=64000):
    x_len = x.shape[0]
    if x_len > max_len:
        stt = np.random.randint(x_len - max_len)
        return x[stt:stt + max_len]

    pad_length = max_len - x_len
    # keep the original dtype so the feature extractor still receives float32 audio
    padded_x = np.concatenate([x, np.zeros(pad_length, dtype=x.dtype)], axis=0)
    return padded_x

class AudioDeepfakeDetector:
    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.models = {}
        self.feature_extractors = {}
        self.current_model = None

        print(f"Using device: {self.device}")
        print("Audio deepfake detector initialized")

    def load_model(self, model_type):
        """Load the specified model type, caching it after the first call."""
        if model_type in self.models:
            self.current_model = model_type
            return

        try:
            print(f"Loading {model_type} model...")

            if model_type == "Wave2Vec2BERT":
                model_name = 'facebook/w2v-bert-2.0'
                self.feature_extractors[model_type] = AutoFeatureExtractor.from_pretrained(model_name)
                self.models[model_type] = Wav2Vec2ForSpeechClassification(model_name).to(self.device)

                # Fetch the fine-tuned WaveFake checkpoint; fall back to the
                # pretrained backbone if the download or load fails.
                try:
                    from huggingface_hub import hf_hub_download
                    checkpoint_path = hf_hub_download(
                        repo_id="TrustSafeAI/AudioDeepfakeDetectors",
                        filename="wave2vec2bert_wavefake.pth",
                        cache_dir="./models"
                    )
                    ckpt = torch.load(checkpoint_path, map_location=self.device)
                    self.models[model_type].load_state_dict(ckpt)
                    print(f"Loaded checkpoint for {model_type}")
                except Exception as e:
                    print(f"Could not load checkpoint for {model_type}: {e}")
                    print("Using pretrained weights only")

            elif model_type == "HuBERT":
                model_name = 'facebook/hubert-large-ls960-ft'
                self.feature_extractors[model_type] = AutoFeatureExtractor.from_pretrained(model_name)
                self.models[model_type] = HuBERT(model_name).to(self.device)

                # Fetch the fine-tuned WaveFake checkpoint; fall back to the
                # pretrained backbone if the download or load fails.
                try:
                    from huggingface_hub import hf_hub_download
                    checkpoint_path = hf_hub_download(
                        repo_id="TrustSafeAI/AudioDeepfakeDetectors",
                        filename="hubert_large_wavefake.pth",
                        cache_dir="./models"
                    )
                    ckpt = torch.load(checkpoint_path, map_location=self.device)
                    self.models[model_type].load_state_dict(ckpt)
                    print(f"Loaded checkpoint for {model_type}")
                except Exception as e:
                    print(f"Could not load checkpoint for {model_type}: {e}")
                    print("Using pretrained weights only")

            # Inference only: disable dropout in the classification head.
            self.models[model_type].eval()
            self.current_model = model_type
            print(f"{model_type} model loaded successfully")

        except Exception as e:
            print(f"Error loading {model_type} model: {str(e)}")
            raise

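    # Preprocessing: resample to 16 kHz, crop/zero-pad to 4 s, then add a leading
    # batch dimension so the feature extractor receives a (1, samples) array.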
    def preprocess_audio(self, audio_path, target_sr=16000, max_length=4):
        try:
            print(f"Loading audio file: {os.path.basename(audio_path)}")

            audio, sr = librosa.load(audio_path, sr=target_sr)
            original_duration = len(audio) / sr

            audio = pad(audio).reshape(-1)
            audio = audio[np.newaxis, :]

            print(f"Audio loaded successfully: {original_duration:.2f}s, {sr}Hz")
            return audio, sr

        except Exception as e:
            print(f"Audio processing error: {str(e)}")
            raise

    def extract_features(self, audio, sr, model_type):
        print("Extracting audio features...")
        feature_extractor = self.feature_extractors[model_type]

        inputs = feature_extractor(audio, sampling_rate=sr, return_attention_mask=True,
                                   padding_value=0, return_tensors="pt").to(self.device)
        print("Feature extraction completed")
        return inputs
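
    # The released checkpoints are assumed to put the "fake" class at label index 0
    # and "real" at index 1; classifier() below reads the fake probability from
    # softmax(logits)[0][0].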
    def classifier(self, features, model_type):
        model = self.models[model_type]
        with torch.no_grad():
            outputs = model(**features)
            prob = outputs.logits.softmax(dim=-1)
            fake_prob = prob[0][0].item()

        return fake_prob
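
    # predict(): end-to-end pipeline -- load the requested model, preprocess the
    # file, extract features, score it, and package the result. A fixed 0.5
    # threshold on the fake probability decides the reported verdict.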
    def predict(self, audio_path, model_type):
        try:
            print("Starting analysis...")
            self.load_model(model_type)
            audio, sr = self.preprocess_audio(audio_path)

            features = self.extract_features(audio, sr, model_type)

            fake_probability = self.classifier(features, model_type)
            real_probability = 1 - fake_probability

            threshold = 0.5
            if fake_probability > threshold:
                status = "SUSPICIOUS"
                prediction = "Likely fake audio"
                confidence = fake_probability
                color = "red"
            else:
                status = "LIKELY_REAL"
                prediction = "Likely real audio"
                confidence = real_probability
                color = "green"

            print(f"\n{'='*50}")
            print(f"Result: {prediction}")
            print(f"Confidence: {confidence:.1%}")
            print(f"Real Probability: {real_probability:.1%}")
            print(f"Fake Probability: {fake_probability:.1%}")
            print(f"{'='*50}")

            # audio has shape (1, samples) after preprocessing, so take the last
            # axis for the duration of the cropped/padded clip.
            duration = audio.shape[-1] / sr
            file_size = os.path.getsize(audio_path) / 1024

            result_data = {
                "status": status,
                "prediction": prediction,
                "confidence": confidence,
                "real_probability": real_probability,
                "fake_probability": fake_probability,
                "duration": duration,
                "sample_rate": sr,
                "file_size_kb": file_size,
                "model_used": model_type
            }

            return result_data

        except Exception as e:
            print(f"Failed: {str(e)}")
            return {"error": str(e)}

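# A single detector instance is shared by all Gradio callbacks; models are loaded
# lazily (and cached) the first time each one is requested.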
detector = AudioDeepfakeDetector()

def analyze_uploaded_audio(audio_file, model_choice):
    if audio_file is None:
        return "Please upload an audio file", {}

    try:
        result = detector.predict(audio_file, model_choice)

        if "error" in result:
            return f"Error: {result['error']}", {}

        status_color = "#ff4444" if result['status'] == "SUSPICIOUS" else "#44ff44"

        result_html = f"""
        <div style="padding: 20px; border-radius: 10px; background-color: {status_color}20; border: 2px solid {status_color};">
            <h3 style="color: {status_color}; margin-top: 0;">{result['prediction']}</h3>
            <p><strong>Status:</strong> {result['status']}</p>
            <p><strong>Confidence:</strong> {result['confidence']:.1%}</p>
        </div>
        """

        analysis_data = {
            "status": result['status'],
            "real_probability": f"{result['real_probability']:.1%}",
            "fake_probability": f"{result['fake_probability']:.1%}",
        }

        return result_html, analysis_data

    except Exception as e:
        error_html = f"""
        <div style="padding: 20px; border-radius: 10px; background-color: #ff444420; border: 2px solid #ff4444;">
            <h3 style="color: #ff4444;">Processing error</h3>
            <p>{str(e)}</p>
        </div>
        """
        return error_html, {"error": str(e)}

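# Build the Gradio Blocks UI: paper/project links up top, a model selector and
# audio upload on the left, and the HTML verdict plus JSON details on the right.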
def create_audio_interface():
    with gr.Blocks(title="Audio Deepfake Detection", theme=gr.themes.Soft()) as interface:
        gr.Markdown("""
<div style="text-align: center; margin-bottom: 30px;">
<h1 style="font-size: 28px; font-weight: bold; margin-bottom: 20px; color: #333;">
Measuring the Robustness of Audio Deepfake Detection under Real-World Corruptions
</h1>
<p style="font-size: 16px; color: #666; margin-bottom: 15px;">
Audio deepfake detectors based on Wave2Vec2BERT and HuBERT speech foundation models (fine-tuned on the WaveFake dataset).
</p>
<div style="font-size: 14px; color: #555; line-height: 1.8; text-align: left;">
<p><strong>Paper:</strong> <a href="https://arxiv.org/pdf/2503.17577" target="_blank" style="color: #4285f4; text-decoration: none;">https://arxiv.org/pdf/2503.17577</a></p>
<p><strong>Project Page:</strong> <a href="https://huggingface.co/spaces/TrustSafeAI/AudioPerturber" target="_blank" style="color: #4285f4; text-decoration: none;">https://huggingface.co/spaces/TrustSafeAI/AudioPerturber</a></p>
<p><strong>Model Checkpoints:</strong> <a href="https://huggingface.co/TrustSafeAI/AudioDeepfakeDetectors" target="_blank" style="color: #4285f4; text-decoration: none;">https://huggingface.co/TrustSafeAI/AudioDeepfakeDetectors</a></p>
<p><strong>GitHub Codebase:</strong> <a href="https://github.com/Jessegator/Audio_robustness_evaluation" target="_blank" style="color: #4285f4; text-decoration: none;">https://github.com/Jessegator/Audio_robustness_evaluation</a></p>
</div>
</div>
<hr style="margin: 30px 0; border: none; border-top: 1px solid #e0e0e0;">
""")

        gr.Markdown("""
# Audio Deepfake Detection

**Supported formats**: .wav, .mp3, .flac, .m4a, etc.
""")

        with gr.Row():
            with gr.Column(scale=1):
                model_choice = gr.Dropdown(
                    choices=["Wave2Vec2BERT", "HuBERT"],
                    value="Wave2Vec2BERT",
                    label="Select Model",
                    info="Choose the foundation model for detection"
                )

                audio_input = gr.Audio(
                    label="Upload audio file",
                    type="filepath",
                    show_label=True,
                    interactive=True
                )

                analyze_btn = gr.Button(
                    "Start analyzing",
                    variant="primary",
                    size="lg"
                )

                gr.Markdown("### Play uploaded audio")
                audio_player = gr.Audio(
                    label="Audio Player",
                    interactive=False,
                    show_label=False
                )

            with gr.Column(scale=1):
                result_display = gr.HTML(
                    label="Results",
                    value="<p style='text-align: center; color: #666;'>Waiting for an audio upload...</p>"
                )

                analysis_json = gr.JSON(
                    label="Detailed analysis",
                    value={}
                )

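        # Helper plus callback wiring: analyze automatically on upload, on the
        # button click, and again whenever the model selection changes.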
        def update_player_and_analyze(audio_file, model_type):
            if audio_file is not None:
                result_html, result_data = analyze_uploaded_audio(audio_file, model_type)
                return audio_file, result_html, result_data
            else:
                return None, "<p style='text-align: center; color: #666;'>Waiting for an audio upload...</p>", {}

        audio_input.change(
            fn=update_player_and_analyze,
            inputs=[audio_input, model_choice],
            outputs=[audio_player, result_display, analysis_json]
        )

        analyze_btn.click(
            fn=analyze_uploaded_audio,
            inputs=[audio_input, model_choice],
            outputs=[result_display, analysis_json]
        )

        model_choice.change(
            fn=lambda audio_file, model_type: analyze_uploaded_audio(audio_file, model_type) if audio_file is not None else ("Please upload audio first", {}),
            inputs=[audio_input, model_choice],
            outputs=[result_display, analysis_json]
        )

    return interface


if __name__ == "__main__":
    print("Creating interface...")
    demo = create_audio_interface()

    print("Launching...")
    demo.launch(
        share=False,
        debug=True,
        show_error=True
    )