import gradio as gr
import librosa
import librosa.display  # explicit submodule import; required on older librosa versions for specshow
import numpy as np
import torch
import matplotlib.pyplot as plt
from transformers import AutoModelForAudioClassification, ASTFeatureExtractor
import random
import tempfile
import logging
import os

logging.basicConfig(level=logging.DEBUG, filename='app.log', filemode='w', format='%(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

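# Load the fine-tuned audio classification model and its AST feature extractor
# from the Space's root directory (the checkpoint files live alongside this script).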
model = AutoModelForAudioClassification.from_pretrained("./")
feature_extractor = ASTFeatureExtractor.from_pretrained("./")

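# Plot the time-domain waveform and save it as a temporary PNG for the Gradio image output.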
def plot_waveform(waveform, sr):
    plt.figure(figsize=(24, 8))  # Doubled size for larger visuals
    plt.title('Waveform')
    plt.ylabel('Amplitude')
    plt.plot(np.linspace(0, len(waveform) / sr, len(waveform)), waveform)
    plt.xlabel('Time (s)')
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
    plt.savefig(temp_file.name)
    plt.close()
    logger.debug(f"Waveform image generated: {temp_file.name}, Size: {os.path.getsize(temp_file.name)} bytes")
    return temp_file.name

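# Compute a mel spectrogram (in dB) and save it as a temporary PNG for the Gradio image output.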
def plot_spectrogram(waveform, sr):
    S = librosa.feature.melspectrogram(y=waveform, sr=sr, n_mels=128)
    S_DB = librosa.power_to_db(S, ref=np.max)
    plt.figure(figsize=(24, 12))  # Doubled size for larger visuals
    librosa.display.specshow(S_DB, sr=sr, x_axis='time', y_axis='mel')
    plt.title('Mel Spectrogram')
    plt.colorbar(format='%+2.0f dB')
    plt.tight_layout()
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
    plt.savefig(temp_file.name)
    plt.close()
    logger.debug(f"Spectrogram image generated: {temp_file.name}, Size: {os.path.getsize(temp_file.name)} bytes")
    return temp_file.name

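# Convert raw audio into the model's expected input_values using the AST feature extractor;
# the extractor pads or truncates the features to its configured maximum length.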
def custom_feature_extraction(audio, sr=16000, target_length=1024):
    features = feature_extractor(audio, sampling_rate=sr, return_tensors="pt", padding="max_length", max_length=target_length)
    return features.input_values

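# Data augmentation: circularly shift the waveform by a random offset (up to ±10% of its length).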
def apply_time_shift(waveform, max_shift_fraction=0.1):
    shift = random.randint(-int(max_shift_fraction * len(waveform)), int(max_shift_fraction * len(waveform)))
    return np.roll(waveform, shift)

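# Full inference pipeline: load the audio at 16 kHz, run the model on both the original and a
# time-shifted copy, average the logits, and return the label, confidence, and both plots.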
def predict_voice(audio_file_path):
    waveform, _ = librosa.load(audio_file_path, sr=16000, mono=True)  # Ensure all audio is resampled to 16kHz
    augmented_waveform = apply_time_shift(waveform)
    original_features = custom_feature_extraction(waveform, sr=16000)  # Adjusted sample rate to 16kHz
    augmented_features = custom_feature_extraction(augmented_waveform, sr=16000)  # Adjusted sample rate to 16kHz
    with torch.no_grad():
        outputs_original = model(original_features)
        outputs_augmented = model(augmented_features)
    logits = (outputs_original.logits + outputs_augmented.logits) / 2
    predicted_index = logits.argmax()
    original_label = model.config.id2label[predicted_index.item()]
    confidence = torch.softmax(logits, dim=1).max().item() * 100
    label_mapping = {"Spoof": "AI-generated Clone", "Bonafide": "Real Human Voice"}
    new_label = label_mapping.get(original_label, "Unknown")
    waveform_plot = plot_waveform(waveform, 16000)  # Adjusted sample rate to 16kHz
    spectrogram_plot = plot_spectrogram(waveform, 16000)  # Adjusted sample rate to 16kHz
    return (f"The voice is classified as '{new_label}' with a confidence of {confidence:.2f}%.",
            waveform_plot,
            spectrogram_plot)

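# Gradio UI: an audio upload, a detect button, and outputs for the prediction text,
# waveform image, and spectrogram image.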
with gr.Blocks(css="style.css") as demo:
    gr.Markdown("## Voice Clone Detection")
    gr.Markdown("Detects whether a voice is real or an AI-generated clone. Upload an audio file to see the results.")
    with gr.Row():
        audio_input = gr.Audio(label="Upload Audio File", type="filepath")
    detect_button = gr.Button("Detect Voice Clone")
    with gr.Row():
        prediction_output = gr.Textbox(label="Prediction")
    with gr.Row():
        waveform_output = gr.Image(label="Waveform")
        spectrogram_output = gr.Image(label="Spectrogram")
    detect_button.click(fn=predict_voice, inputs=[audio_input], outputs=[prediction_output, waveform_output, spectrogram_output])

demo.launch()