|
import gradio as gr |
|
import librosa |
|
import numpy as np |
|
import torch |
|
import matplotlib.pyplot as plt |
|
from transformers import AutoModelForAudioClassification, ASTFeatureExtractor, Wav2Vec2Processor, Wav2Vec2ForCTC |
|
import random |
|
import tempfile |
|
|
|
|
|
# Speech-to-text pipeline: wav2vec2 (CTC) used only to transcribe the upload.
wav2vec_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
wav2vec_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

# Clone-detection classifier: an AST audio-classification model loaded from the
# current directory — presumably the fine-tuned weights/config ship alongside
# this app (TODO confirm the checkpoint layout in "./").
model = AutoModelForAudioClassification.from_pretrained("./")
feature_extractor = ASTFeatureExtractor.from_pretrained("./")
|
|
|
def plot_waveform(waveform, sr):
    """Render the time-domain waveform and save it as a temporary PNG.

    Args:
        waveform: 1-D array of audio samples.
        sr: Sample rate in Hz (used to build the time axis).

    Returns:
        Path to the saved PNG file. ``delete=False`` means the caller
        (Gradio) can read it after return; it is not cleaned up here.
    """
    plt.figure(figsize=(12, 4))
    plt.title('Waveform')
    plt.ylabel('Amplitude')
    plt.plot(np.linspace(0, len(waveform) / sr, len(waveform)), waveform)
    plt.xlabel('Time (s)')
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
    # Close the handle before writing: the original leaked one file descriptor
    # per call, and an open handle can block savefig on some platforms.
    temp_file.close()
    plt.savefig(temp_file.name)
    plt.close()
    return temp_file.name
|
|
|
def plot_spectrogram(waveform, sr):
    """Render a mel spectrogram (dB scale) and save it as a temporary PNG.

    Args:
        waveform: 1-D array of audio samples.
        sr: Sample rate in Hz.

    Returns:
        Path to the saved PNG file (not cleaned up here; Gradio reads it).
    """
    # `specshow` lives in a submodule that a plain `import librosa` is not
    # guaranteed to load; import it explicitly to avoid an AttributeError.
    import librosa.display
    S = librosa.feature.melspectrogram(y=waveform, sr=sr, n_mels=128)
    S_DB = librosa.power_to_db(S, ref=np.max)
    plt.figure(figsize=(12, 6))
    librosa.display.specshow(S_DB, sr=sr, x_axis='time', y_axis='mel', cmap='inferno')
    plt.title('Mel Spectrogram')
    plt.colorbar(format='%+2.0f dB')
    plt.tight_layout()
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
    # Close the handle before writing (the original leaked a descriptor per call).
    temp_file.close()
    plt.savefig(temp_file.name)
    plt.close()
    return temp_file.name
|
|
|
def custom_feature_extraction(audio, sr=16000, target_length=1024):
    """Run the module-level AST feature extractor on raw audio.

    Pads/truncates to `target_length` frames and returns the `input_values`
    tensor ready to feed the classifier.
    """
    encoded = feature_extractor(
        audio,
        sampling_rate=sr,
        return_tensors="pt",
        padding="max_length",
        max_length=target_length,
    )
    return encoded.input_values
|
|
|
def apply_time_shift(waveform, max_shift_fraction=0.1):
    """Circularly shift the waveform by a random offset (simple augmentation).

    The offset is drawn uniformly from ±`max_shift_fraction` of the signal
    length; samples rolled off one end wrap around to the other.
    """
    max_shift = int(max_shift_fraction * len(waveform))
    offset = random.randint(-max_shift, max_shift)
    return np.roll(waveform, offset)
|
|
|
def transcribe_audio(audio_file_path):
    """Transcribe an audio file with the module-level wav2vec2 CTC model.

    Loads the file mono at the processor's expected sample rate, runs a
    greedy (argmax) decode, and returns the batch of decoded strings
    (a list; callers index [0] for the single-file case).
    """
    target_sr = wav2vec_processor.feature_extractor.sampling_rate
    waveform, _ = librosa.load(audio_file_path, sr=target_sr, mono=True)
    encoded = wav2vec_processor(waveform, return_tensors="pt", padding="longest")
    with torch.no_grad():
        logits = wav2vec_model(encoded.input_values).logits
    token_ids = torch.argmax(logits, dim=-1)
    return wav2vec_processor.batch_decode(token_ids)
|
|
|
def predict_voice(audio_file_path):
    """Classify an uploaded voice as real or AI-cloned.

    Pipeline: transcribe the audio, classify both the original and a
    time-shifted copy (cheap test-time augmentation), average the logits,
    and render waveform/spectrogram plots.

    Returns:
        (message, waveform_png_path, spectrogram_png_path, transcript) on
        success; (error_message, None, None, "") on any failure.
    """
    try:
        transcription = transcribe_audio(audio_file_path)

        waveform, sample_rate = librosa.load(
            audio_file_path, sr=feature_extractor.sampling_rate, mono=True
        )
        shifted_waveform = apply_time_shift(waveform)

        plain_features = custom_feature_extraction(waveform, sr=sample_rate)
        shifted_features = custom_feature_extraction(shifted_waveform, sr=sample_rate)

        with torch.no_grad():
            plain_outputs = model(plain_features)
            shifted_outputs = model(shifted_features)

        # Average the two logit sets so a lucky/unlucky shift matters less.
        logits = (plain_outputs.logits + shifted_outputs.logits) / 2
        predicted_index = logits.argmax()
        original_label = model.config.id2label[predicted_index.item()]
        confidence = torch.softmax(logits, dim=1).max().item() * 100

        # Map the checkpoint's anti-spoofing labels to user-facing names.
        label_mapping = {
            "Spoof": "AI-generated Clone",
            "Bonafide": "Real Human Voice",
        }
        new_label = label_mapping.get(original_label, "Unknown")

        waveform_plot = plot_waveform(waveform, sample_rate)
        spectrogram_plot = plot_spectrogram(waveform, sample_rate)

        return (
            f"The voice is classified as '{new_label}' with a confidence of {confidence:.2f}%.",
            waveform_plot,
            spectrogram_plot,
            transcription[0],
        )
    except Exception as e:
        # Top-level UI boundary: surface the error in the prediction box
        # instead of crashing the Gradio worker.
        return f"Error during processing: {e}", None, None, ""
|
|
|
# --- Gradio UI -------------------------------------------------------------
# One audio input feeds predict_voice; its four return values map 1:1 onto
# the four output components wired in detect_button.click below.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown("## Voice Clone Detection")
    gr.Markdown("Detects whether a voice is real or an AI-generated clone. Upload an audio file to see the results.")

    with gr.Row():
        # type="filepath": predict_voice receives a path, not raw samples.
        audio_input = gr.Audio(label="Upload Audio File", type="filepath")

    with gr.Row():
        prediction_output = gr.Textbox(label="Prediction")
        transcription_output = gr.Textbox(label="Transcription")
        waveform_output = gr.Image(label="Waveform")
        spectrogram_output = gr.Image(label="Spectrogram")

    detect_button = gr.Button("Detect Voice Clone")
    # Output order must match predict_voice's return tuple:
    # (message, waveform_plot, spectrogram_plot, transcript).
    detect_button.click(
        fn=predict_voice,
        inputs=[audio_input],
        outputs=[prediction_output, waveform_output, spectrogram_output, transcription_output]
    )


demo.launch()