import gradio as gr
import librosa
import librosa.display  # explicit import so librosa.display.specshow is available on all librosa versions
import numpy as np
import torch
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: figures are only saved to files, never shown
import matplotlib.pyplot as plt
from transformers import AutoModelForAudioClassification, ASTFeatureExtractor, Wav2Vec2Processor, Wav2Vec2ForCTC
import random
import tempfile

# Load the Wav2Vec 2.0 processor and CTC model used for transcription
wav2vec_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
wav2vec_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

# Load the voice-spoof classifier and its AST feature extractor from the local checkpoint
model = AutoModelForAudioClassification.from_pretrained("./")
feature_extractor = ASTFeatureExtractor.from_pretrained("./")

def plot_waveform(waveform, sr):
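    # Plot the time-domain waveform and save it to a temporary PNG whose path Gradio can display.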
    plt.figure(figsize=(12, 4))
    plt.title('Waveform')
    plt.ylabel('Amplitude')
    plt.plot(np.linspace(0, len(waveform) / sr, len(waveform)), waveform)
    plt.xlabel('Time (s)')
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
    plt.savefig(temp_file.name)
    plt.close()
    return temp_file.name

def plot_spectrogram(waveform, sr):
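    # Compute a 128-band mel spectrogram, convert it to dB, and save the figure as a temporary PNG.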
    S = librosa.feature.melspectrogram(y=waveform, sr=sr, n_mels=128)
    S_DB = librosa.power_to_db(S, ref=np.max)
    plt.figure(figsize=(12, 6))
    librosa.display.specshow(S_DB, sr=sr, x_axis='time', y_axis='mel', cmap='inferno')
    plt.title('Mel Spectrogram')
    plt.colorbar(format='%+2.0f dB')
    plt.tight_layout()
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
    plt.savefig(temp_file.name)
    plt.close()
    return temp_file.name

def custom_feature_extraction(audio, sr=16000, target_length=1024):
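    # Turn the raw waveform into the input tensor expected by the classifier's feature extractor.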
    features = feature_extractor(audio, sampling_rate=sr, return_tensors="pt", padding="max_length", max_length=target_length)
    return features.input_values

def apply_time_shift(waveform, max_shift_fraction=0.1):
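    # Test-time augmentation: circularly shift the waveform by up to ±max_shift_fraction of its length.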
    shift = random.randint(-int(max_shift_fraction * len(waveform)), int(max_shift_fraction * len(waveform)))
    return np.roll(waveform, shift)

def transcribe_audio(audio_file_path):
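    # Transcribe speech with Wav2Vec 2.0 using greedy (argmax) decoding of the CTC logits.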
    waveform, _ = librosa.load(audio_file_path, sr=wav2vec_processor.feature_extractor.sampling_rate, mono=True)
    input_values = wav2vec_processor(waveform, return_tensors="pt", padding="longest").input_values
    with torch.no_grad():
        logits = wav2vec_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = wav2vec_processor.batch_decode(predicted_ids)
    return transcription

def predict_voice(audio_file_path):
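    # Full pipeline: transcribe the audio, classify both the original and a time-shifted copy,
    # average their logits, and render the waveform and spectrogram plots.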
    try:
        transcription = transcribe_audio(audio_file_path)
        
        waveform, sample_rate = librosa.load(audio_file_path, sr=feature_extractor.sampling_rate, mono=True)
        augmented_waveform = apply_time_shift(waveform)
        
        original_features = custom_feature_extraction(waveform, sr=sample_rate)
        augmented_features = custom_feature_extraction(augmented_waveform, sr=sample_rate)
        
        with torch.no_grad():
            outputs_original = model(original_features)
            outputs_augmented = model(augmented_features)
        
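        # Average the logits from the original and time-shifted views to make the
        # prediction slightly more robust to where the speech sits in the clip.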
        logits = (outputs_original.logits + outputs_augmented.logits) / 2
        predicted_index = logits.argmax()
        original_label = model.config.id2label[predicted_index.item()]
        confidence = torch.softmax(logits, dim=1).max().item() * 100
        
        label_mapping = {
            "Spoof": "AI-generated Clone",
            "Bonafide": "Real Human Voice"
        }
        new_label = label_mapping.get(original_label, "Unknown")
        
        waveform_plot = plot_waveform(waveform, sample_rate)
        spectrogram_plot = plot_spectrogram(waveform, sample_rate)
        
        return (
            f"The voice is classified as '{new_label}' with a confidence of {confidence:.2f}%.",
            waveform_plot,
            spectrogram_plot,
            transcription[0]  # batch_decode returns a list; take the single transcription
        )
    except Exception as e:
        return f"Error during processing: {e}", None, None, ""

with gr.Blocks(css="style.css") as demo:
    gr.Markdown("## Voice Clone Detection")
    gr.Markdown("Detects whether a voice is real or an AI-generated clone. Upload an audio file to see the results.")
    
    with gr.Row():
        audio_input = gr.Audio(label="Upload Audio File", type="filepath")

    with gr.Row():
        prediction_output = gr.Textbox(label="Prediction")
        transcription_output = gr.Textbox(label="Transcription")
        waveform_output = gr.Image(label="Waveform")
        spectrogram_output = gr.Image(label="Spectrogram")

    detect_button = gr.Button("Detect Voice Clone")
    detect_button.click(
        fn=predict_voice,
        inputs=[audio_input],
        outputs=[prediction_output, waveform_output, spectrogram_output, transcription_output]
    )

# Launch the interface
demo.launch()