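"""Voice Clone Detection demo.

Classifies an uploaded audio file as a real human voice or an AI-generated
clone with a fine-tuned Audio Spectrogram Transformer (AST), transcribes the
speech with Wav2Vec 2.0, and renders waveform and mel-spectrogram plots.
"""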
import random
import tempfile

import gradio as gr
import librosa
import librosa.display  # specshow is not exposed by `import librosa` alone in older releases
import matplotlib
matplotlib.use("Agg")  # headless backend: plots are saved to files, never shown in a window
import matplotlib.pyplot as plt
import numpy as np
import torch
from transformers import AutoModelForAudioClassification, ASTFeatureExtractor, Wav2Vec2Processor, Wav2Vec2ForCTC
# Load the Wav2Vec 2.0 ASR model used for transcription
wav2vec_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
wav2vec_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
wav2vec_model.eval()  # inference only: disable dropout

# Load the AST classifier and its feature extractor from this repository's root
model = AutoModelForAudioClassification.from_pretrained("./")
feature_extractor = ASTFeatureExtractor.from_pretrained("./")
model.eval()
def plot_waveform(waveform, sr):
    """Plot the time-domain waveform and return the path of the saved PNG."""
    plt.figure(figsize=(12, 4))
    plt.title('Waveform')
    plt.ylabel('Amplitude')
    plt.plot(np.linspace(0, len(waveform) / sr, len(waveform)), waveform)
    plt.xlabel('Time (s)')
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
    plt.savefig(temp_file.name)
    plt.close()
    return temp_file.name
def plot_spectrogram(waveform, sr):
    """Plot a mel spectrogram in dB and return the path of the saved PNG."""
    S = librosa.feature.melspectrogram(y=waveform, sr=sr, n_mels=128)
    S_DB = librosa.power_to_db(S, ref=np.max)
    plt.figure(figsize=(12, 6))
    librosa.display.specshow(S_DB, sr=sr, x_axis='time', y_axis='mel', cmap='inferno')
    plt.title('Mel Spectrogram')
    plt.colorbar(format='%+2.0f dB')
    plt.tight_layout()
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
    plt.savefig(temp_file.name)
    plt.close()
    return temp_file.name
def custom_feature_extraction(audio, sr=16000, target_length=1024):
    """Extract AST input features; the extractor pads/truncates to a fixed frame length."""
    features = feature_extractor(audio, sampling_rate=sr, return_tensors="pt", padding="max_length", max_length=target_length)
    return features.input_values
def apply_time_shift(waveform, max_shift_fraction=0.1):
    """Circularly shift the waveform by up to ±max_shift_fraction of its length."""
    shift = random.randint(-int(max_shift_fraction * len(waveform)), int(max_shift_fraction * len(waveform)))
    return np.roll(waveform, shift)
def transcribe_audio(audio_file_path):
    """Transcribe speech with Wav2Vec 2.0; returns a list containing one string."""
    waveform, _ = librosa.load(audio_file_path, sr=wav2vec_processor.feature_extractor.sampling_rate, mono=True)
    input_values = wav2vec_processor(waveform, return_tensors="pt", padding="longest").input_values
    with torch.no_grad():
        logits = wav2vec_model(input_values).logits
    # Greedy CTC decoding: per-frame argmax, then batch_decode collapses repeats and blanks
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = wav2vec_processor.batch_decode(predicted_ids)
    return transcription
def predict_voice(audio_file_path):
    """Classify an audio file as a real or cloned voice and build all UI outputs."""
    try:
        transcription = transcribe_audio(audio_file_path)
        waveform, sample_rate = librosa.load(audio_file_path, sr=feature_extractor.sampling_rate, mono=True)
        augmented_waveform = apply_time_shift(waveform)

        original_features = custom_feature_extraction(waveform, sr=sample_rate)
        augmented_features = custom_feature_extraction(augmented_waveform, sr=sample_rate)
        with torch.no_grad():
            outputs_original = model(original_features)
            outputs_augmented = model(augmented_features)

        # Test-time augmentation: average the logits of the original and shifted clips
        logits = (outputs_original.logits + outputs_augmented.logits) / 2
        predicted_index = logits.argmax()
        original_label = model.config.id2label[predicted_index.item()]
        confidence = torch.softmax(logits, dim=1).max().item() * 100

        # Map the anti-spoofing label names to user-facing ones
        label_mapping = {
            "Spoof": "AI-generated Clone",
            "Bonafide": "Real Human Voice"
        }
        new_label = label_mapping.get(original_label, "Unknown")

        waveform_plot = plot_waveform(waveform, sample_rate)
        spectrogram_plot = plot_spectrogram(waveform, sample_rate)
        return (
            f"The voice is classified as '{new_label}' with a confidence of {confidence:.2f}%.",
            waveform_plot,
            spectrogram_plot,
            transcription[0]  # batch_decode returns a list with a single string
        )
    except Exception as e:
        return f"Error during processing: {e}", None, None, ""
with gr.Blocks(css="style.css") as demo:
    gr.Markdown("## Voice Clone Detection")
    gr.Markdown("Detects whether a voice is real or an AI-generated clone. Upload an audio file to see the results.")
    with gr.Row():
        audio_input = gr.Audio(label="Upload Audio File", type="filepath")
    with gr.Row():
        prediction_output = gr.Textbox(label="Prediction")
        transcription_output = gr.Textbox(label="Transcription")
    waveform_output = gr.Image(label="Waveform")
    spectrogram_output = gr.Image(label="Spectrogram")
    detect_button = gr.Button("Detect Voice Clone")
    detect_button.click(
        fn=predict_voice,
        inputs=[audio_input],
        outputs=[prediction_output, waveform_output, spectrogram_output, transcription_output]
    )
# Launch the interface
demo.launch()
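# Run locally with `python app.py`; Gradio prints a local URL once the server
# is up. Passing share=True to demo.launch() would additionally create a
# temporary public link (optional; not enabled here).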