import gradio as gr
import librosa
import librosa.display  # explicit import needed for librosa.display.specshow on some librosa versions
import numpy as np
import torch
import matplotlib.pyplot as plt
from transformers import AutoModelForAudioClassification, ASTFeatureExtractor
import random
import tempfile
import logging
import os
logging.basicConfig(level=logging.DEBUG, filename='app.log', filemode='w', format='%(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
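
# Load the fine-tuned audio classification model and its AST feature extractor
# from the Space's root directory (the checkpoint files live alongside app.py).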
model = AutoModelForAudioClassification.from_pretrained("./")
feature_extractor = ASTFeatureExtractor.from_pretrained("./")
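
# Plot the time-domain waveform and save it to a temporary PNG for display in the UI.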
def plot_waveform(waveform, sr):
    plt.figure(figsize=(24, 8))  # Doubled size for larger visuals
    plt.title('Waveform')
    plt.ylabel('Amplitude')
    plt.plot(np.linspace(0, len(waveform) / sr, len(waveform)), waveform)
    plt.xlabel('Time (s)')
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
    plt.savefig(temp_file.name)
    plt.close()
    logger.debug(f"Waveform image generated: {temp_file.name}, Size: {os.path.getsize(temp_file.name)} bytes")
    return temp_file.name
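
# Compute a 128-band mel spectrogram (in dB) and save it to a temporary PNG for display in the UI.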
def plot_spectrogram(waveform, sr):
    S = librosa.feature.melspectrogram(y=waveform, sr=sr, n_mels=128)
    S_DB = librosa.power_to_db(S, ref=np.max)
    plt.figure(figsize=(24, 12))  # Doubled size for larger visuals
    librosa.display.specshow(S_DB, sr=sr, x_axis='time', y_axis='mel')
    plt.title('Mel Spectrogram')
    plt.colorbar(format='%+2.0f dB')
    plt.tight_layout()
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
    plt.savefig(temp_file.name)
    plt.close()
    logger.debug(f"Spectrogram image generated: {temp_file.name}, Size: {os.path.getsize(temp_file.name)} bytes")
    return temp_file.name
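
# Turn raw audio into the fixed-length input tensor the AST model expects.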
def custom_feature_extraction(audio, sr=16000, target_length=1024):
    features = feature_extractor(audio, sampling_rate=sr, return_tensors="pt", padding="max_length", max_length=target_length)
    return features.input_values
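
# Circularly shift the waveform by up to ±10% of its length; used as a simple test-time augmentation.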
def apply_time_shift(waveform, max_shift_fraction=0.1):
    shift = random.randint(-int(max_shift_fraction * len(waveform)), int(max_shift_fraction * len(waveform)))
    return np.roll(waveform, shift)
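
# Full inference pipeline: load the audio, run the model on the original and a
# time-shifted copy, average the logits, and map the predicted label to a
# human-readable class along with the two plots shown in the UI.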
def predict_voice(audio_file_path):
    waveform, _ = librosa.load(audio_file_path, sr=16000, mono=True)  # Resample all input audio to 16 kHz mono
    augmented_waveform = apply_time_shift(waveform)
    original_features = custom_feature_extraction(waveform, sr=16000)
    augmented_features = custom_feature_extraction(augmented_waveform, sr=16000)
    with torch.no_grad():
        outputs_original = model(original_features)
        outputs_augmented = model(augmented_features)
    logits = (outputs_original.logits + outputs_augmented.logits) / 2
    predicted_index = logits.argmax()
    original_label = model.config.id2label[predicted_index.item()]
    confidence = torch.softmax(logits, dim=1).max().item() * 100
    label_mapping = {"Spoof": "AI-generated Clone", "Bonafide": "Real Human Voice"}
    new_label = label_mapping.get(original_label, "Unknown")
    waveform_plot = plot_waveform(waveform, 16000)
    spectrogram_plot = plot_spectrogram(waveform, 16000)
    return (f"The voice is classified as '{new_label}' with a confidence of {confidence:.2f}%.",
            waveform_plot,
            spectrogram_plot)
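
# Gradio interface: audio upload, detect button, text prediction, and the generated waveform/spectrogram images.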
with gr.Blocks(css="style.css") as demo:
    gr.Markdown("## Voice Clone Detection")
    gr.Markdown("Detects whether a voice is real or an AI-generated clone. Upload an audio file to see the results.")
    with gr.Row():
        audio_input = gr.Audio(label="Upload Audio File", type="filepath")
        detect_button = gr.Button("Detect Voice Clone")
    with gr.Row():
        prediction_output = gr.Textbox(label="Prediction")
    with gr.Row():
        waveform_output = gr.Image(label="Waveform")
        spectrogram_output = gr.Image(label="Spectrogram")
    detect_button.click(fn=predict_voice, inputs=[audio_input], outputs=[prediction_output, waveform_output, spectrogram_output])

demo.launch()