import gradio as gr
import librosa
import librosa.display  # explicit import so librosa.display.specshow works on older librosa releases
import numpy as np
import torch
import matplotlib.pyplot as plt
from transformers import AutoModelForAudioClassification, ASTFeatureExtractor
import random
import tempfile
import logging
import os
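# Write debug output to app.log so image-generation and prediction issues can be inspected after the fact.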
logging.basicConfig(level=logging.DEBUG, filename='app.log', filemode='w', format='%(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
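# Load the fine-tuned AST classifier and its matching feature extractor from the repository root.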
model = AutoModelForAudioClassification.from_pretrained("./")
feature_extractor = ASTFeatureExtractor.from_pretrained("./")
def plot_waveform(waveform, sr):
    """Plot the raw waveform and save it as a temporary PNG, returning the file path."""
    plt.figure(figsize=(24, 8))  # Doubled size for larger visuals
    plt.title('Waveform')
    plt.ylabel('Amplitude')
    plt.plot(np.linspace(0, len(waveform) / sr, len(waveform)), waveform)
    plt.xlabel('Time (s)')
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
    plt.savefig(temp_file.name)
    plt.close()
    logger.debug(f"Waveform image generated: {temp_file.name}, Size: {os.path.getsize(temp_file.name)} bytes")
    return temp_file.name
def plot_spectrogram(waveform, sr):
    """Compute a mel spectrogram, plot it, and save it as a temporary PNG, returning the file path."""
    S = librosa.feature.melspectrogram(y=waveform, sr=sr, n_mels=128)
    S_DB = librosa.power_to_db(S, ref=np.max)
    plt.figure(figsize=(24, 12))  # Doubled size for larger visuals
    librosa.display.specshow(S_DB, sr=sr, x_axis='time', y_axis='mel')
    plt.title('Mel Spectrogram')
    plt.colorbar(format='%+2.0f dB')
    plt.tight_layout()
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
    plt.savefig(temp_file.name)
    plt.close()
    logger.debug(f"Spectrogram image generated: {temp_file.name}, Size: {os.path.getsize(temp_file.name)} bytes")
    return temp_file.name
def custom_feature_extraction(audio, sr=16000, target_length=1024):
    """Run the AST feature extractor on a raw waveform and return the padded input tensor."""
    features = feature_extractor(audio, sampling_rate=sr, return_tensors="pt", padding="max_length", max_length=target_length)
    return features.input_values
def apply_time_shift(waveform, max_shift_fraction=0.1):
    """Circularly shift the waveform in time by up to max_shift_fraction of its length."""
    shift = random.randint(-int(max_shift_fraction * len(waveform)), int(max_shift_fraction * len(waveform)))
    return np.roll(waveform, shift)
def predict_voice(audio_file_path):
    """Classify an uploaded audio file as a real human voice or an AI-generated clone."""
    waveform, _ = librosa.load(audio_file_path, sr=16000, mono=True)  # Resample all audio to 16 kHz mono
    augmented_waveform = apply_time_shift(waveform)
    original_features = custom_feature_extraction(waveform, sr=16000)
    augmented_features = custom_feature_extraction(augmented_waveform, sr=16000)
    with torch.no_grad():
        outputs_original = model(original_features)
        outputs_augmented = model(augmented_features)
    # Average the logits of the original and time-shifted inputs before taking the prediction
    logits = (outputs_original.logits + outputs_augmented.logits) / 2
    predicted_index = logits.argmax()
    original_label = model.config.id2label[predicted_index.item()]
    confidence = torch.softmax(logits, dim=1).max().item() * 100
    label_mapping = {"Spoof": "AI-generated Clone", "Bonafide": "Real Human Voice"}
    new_label = label_mapping.get(original_label, "Unknown")
    waveform_plot = plot_waveform(waveform, 16000)
    spectrogram_plot = plot_spectrogram(waveform, 16000)
    return (f"The voice is classified as '{new_label}' with a confidence of {confidence:.2f}%.",
            waveform_plot,
            spectrogram_plot)
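# Build the Gradio interface: audio upload, detect button, prediction text, and the two visualisations.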
with gr.Blocks(css="style.css") as demo:
    gr.Markdown("## Voice Clone Detection")
    gr.Markdown("Detects whether a voice is real or an AI-generated clone. Upload an audio file to see the results.")
    with gr.Row():
        audio_input = gr.Audio(label="Upload Audio File", type="filepath")
        detect_button = gr.Button("Detect Voice Clone")
    with gr.Row():
        prediction_output = gr.Textbox(label="Prediction")
    with gr.Row():
        waveform_output = gr.Image(label="Waveform")
        spectrogram_output = gr.Image(label="Spectrogram")
    detect_button.click(fn=predict_voice, inputs=[audio_input], outputs=[prediction_output, waveform_output, spectrogram_output])

demo.launch()