import gradio as gr
import librosa
import librosa.display  # required for librosa.display.specshow
import numpy as np
import torch
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; safe on headless servers
import matplotlib.pyplot as plt
from transformers import AutoModelForAudioClassification, ASTFeatureExtractor, Wav2Vec2Processor, Wav2Vec2ForCTC
import random
import tempfile

# Load the Wav2Vec 2.0 processor and model for speech transcription
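# (facebook/wav2vec2-base-960h is an English ASR model trained on 960 hours of LibriSpeech)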
wav2vec_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
wav2vec_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

# Load the voice-clone classifier and its AST feature extractor from this Space's root directory
model = AutoModelForAudioClassification.from_pretrained("./")
feature_extractor = ASTFeatureExtractor.from_pretrained("./")

def plot_waveform(waveform, sr):
    plt.figure(figsize=(12, 4))
    plt.title('Waveform')
    plt.ylabel('Amplitude')
    plt.plot(np.linspace(0, len(waveform) / sr, len(waveform)), waveform)
    plt.xlabel('Time (s)')
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
    plt.savefig(temp_file.name)
    plt.close()
    return temp_file.name

def plot_spectrogram(waveform, sr):
    S = librosa.feature.melspectrogram(y=waveform, sr=sr, n_mels=128)
    S_DB = librosa.power_to_db(S, ref=np.max)
    plt.figure(figsize=(12, 6))
    librosa.display.specshow(S_DB, sr=sr, x_axis='time', y_axis='mel', cmap='inferno')
    plt.title('Mel Spectrogram')
    plt.colorbar(format='%+2.0f dB')
    plt.tight_layout()
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
    plt.savefig(temp_file.name)
    plt.close()
    return temp_file.name

def custom_feature_extraction(audio, sr=16000, target_length=1024):
    # Extract AST spectrogram features; the AST extractor pads/truncates to a fixed
    # number of frames internally (1024 by default)
    features = feature_extractor(audio, sampling_rate=sr, return_tensors="pt", padding="max_length", max_length=target_length)
    return features.input_values

def apply_time_shift(waveform, max_shift_fraction=0.1):
    # Circularly shift the waveform by a random offset (used below for test-time augmentation)
    shift = random.randint(-int(max_shift_fraction * len(waveform)), int(max_shift_fraction * len(waveform)))
    return np.roll(waveform, shift)

def transcribe_audio(audio_file_path):
    # Resample to the rate Wav2Vec2 expects (16 kHz), then decode greedily from the CTC logits
    waveform, _ = librosa.load(audio_file_path, sr=wav2vec_processor.feature_extractor.sampling_rate, mono=True)
    input_values = wav2vec_processor(waveform, sampling_rate=wav2vec_processor.feature_extractor.sampling_rate, return_tensors="pt", padding="longest").input_values
    with torch.no_grad():
        logits = wav2vec_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = wav2vec_processor.batch_decode(predicted_ids)
    return transcription
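
# A minimal usage sketch (hypothetical file path; kept as a comment so it does not run at import time):
#   text = transcribe_audio("sample.wav")[0]  # batch_decode returns a list of strings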

def predict_voice(audio_file_path):
    try:
        transcription = transcribe_audio(audio_file_path)
        waveform, sample_rate = librosa.load(audio_file_path, sr=feature_extractor.sampling_rate, mono=True)
        augmented_waveform = apply_time_shift(waveform)
        original_features = custom_feature_extraction(waveform, sr=sample_rate)
        augmented_features = custom_feature_extraction(augmented_waveform, sr=sample_rate)
        with torch.no_grad():
            outputs_original = model(original_features)
            outputs_augmented = model(augmented_features)
        # Average logits over the original and time-shifted inputs (test-time augmentation)
        logits = (outputs_original.logits + outputs_augmented.logits) / 2
        predicted_index = logits.argmax()
        original_label = model.config.id2label[predicted_index.item()]
        confidence = torch.softmax(logits, dim=1).max().item() * 100
        label_mapping = {
            "Spoof": "AI-generated Clone",
            "Bonafide": "Real Human Voice"
        }
        new_label = label_mapping.get(original_label, "Unknown")
        waveform_plot = plot_waveform(waveform, sample_rate)
        spectrogram_plot = plot_spectrogram(waveform, sample_rate)
        return (
            f"The voice is classified as '{new_label}' with a confidence of {confidence:.2f}%.",
            waveform_plot,
            spectrogram_plot,
            transcription[0]  # batch_decode returns a list; take the single transcription
        )
    except Exception as e:
        return f"Error during processing: {e}", None, None, ""
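
# A minimal sketch of calling the full pipeline without the UI (hypothetical path, commented out):
#   message, waveform_png, spectrogram_png, text = predict_voice("sample.wav")
#   print(message)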

with gr.Blocks(css="style.css") as demo:
    gr.Markdown("## Voice Clone Detection")
    gr.Markdown("Detects whether a voice is real or an AI-generated clone. Upload an audio file to see the results.")
    with gr.Row():
        audio_input = gr.Audio(label="Upload Audio File", type="filepath")
    with gr.Row():
        prediction_output = gr.Textbox(label="Prediction")
        transcription_output = gr.Textbox(label="Transcription")
    waveform_output = gr.Image(label="Waveform")
    spectrogram_output = gr.Image(label="Spectrogram")
    detect_button = gr.Button("Detect Voice Clone")
    detect_button.click(
        fn=predict_voice,
        inputs=[audio_input],
        outputs=[prediction_output, waveform_output, spectrogram_output, transcription_output]
    )

# Launch the interface
demo.launch()