Kabatubare
committed on
Commit
•
9ec21ae
1
Parent(s):
b8277b5
Update app.py
Browse files
app.py
CHANGED
@@ -3,61 +3,43 @@ import librosa
|
|
3 |
import numpy as np
|
4 |
import torch
|
5 |
import matplotlib.pyplot as plt
|
6 |
-
from transformers import AutoModelForAudioClassification, ASTFeatureExtractor
|
7 |
import random
|
8 |
import tempfile
|
9 |
import logging
|
|
|
10 |
|
11 |
logging.basicConfig(level=logging.DEBUG, filename='app.log', filemode='w', format='%(name)s - %(levelname)s - %(message)s')
|
12 |
logger = logging.getLogger(__name__)
|
13 |
|
14 |
-
# Load Wav2Vec 2.0 models
|
15 |
-
wav2vec_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
|
16 |
-
wav2vec_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
17 |
-
|
18 |
-
# Original model and feature extractor loading
|
19 |
model = AutoModelForAudioClassification.from_pretrained("./")
|
20 |
feature_extractor = ASTFeatureExtractor.from_pretrained("./")
|
21 |
|
22 |
def plot_waveform(waveform, sr):
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
file_size = os.path.getsize(temp_file.name)
|
34 |
-
logger.debug(f"Waveform image generated: {temp_file.name}, Size: {file_size} bytes")
|
35 |
-
|
36 |
-
return temp_file.name
|
37 |
-
except Exception as e:
|
38 |
-
logger.error(f"Error generating waveform image: {e}")
|
39 |
-
raise
|
40 |
|
41 |
def plot_spectrogram(waveform, sr):
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
file_size = os.path.getsize(temp_file.name)
|
55 |
-
logger.debug(f"Spectrogram image generated: {temp_file.name}, Size: {file_size} bytes")
|
56 |
-
|
57 |
-
return temp_file.name
|
58 |
-
except Exception as e:
|
59 |
-
logger.error(f"Error generating spectrogram image: {e}")
|
60 |
-
raise
|
61 |
|
62 |
def custom_feature_extraction(audio, sr=16000, target_length=1024):
|
63 |
features = feature_extractor(audio, sampling_rate=sr, return_tensors="pt", padding="max_length", max_length=target_length)
|
@@ -67,72 +49,37 @@ def apply_time_shift(waveform, max_shift_fraction=0.1):
|
|
67 |
shift = random.randint(-int(max_shift_fraction * len(waveform)), int(max_shift_fraction * len(waveform)))
|
68 |
return np.roll(waveform, shift)
|
69 |
|
70 |
-
def transcribe_audio(audio_file_path):
|
71 |
-
waveform, _ = librosa.load(audio_file_path, sr=wav2vec_processor.feature_extractor.sampling_rate, mono=True)
|
72 |
-
input_values = wav2vec_processor(waveform, return_tensors="pt", padding="longest").input_values
|
73 |
-
with torch.no_grad():
|
74 |
-
logits = wav2vec_model(input_values).logits
|
75 |
-
predicted_ids = torch.argmax(logits, dim=-1)
|
76 |
-
transcription = wav2vec_processor.batch_decode(predicted_ids)
|
77 |
-
return transcription
|
78 |
-
|
79 |
def predict_voice(audio_file_path):
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
confidence = torch.softmax(logits, dim=1).max().item() * 100
|
97 |
-
|
98 |
-
label_mapping = {
|
99 |
-
"Spoof": "AI-generated Clone",
|
100 |
-
"Bonafide": "Real Human Voice"
|
101 |
-
}
|
102 |
-
new_label = label_mapping.get(original_label, "Unknown")
|
103 |
-
|
104 |
-
waveform_plot = plot_waveform(waveform, sample_rate)
|
105 |
-
spectrogram_plot = plot_spectrogram(waveform, sample_rate)
|
106 |
-
|
107 |
-
return (
|
108 |
-
f"The voice is classified as '{new_label}' with a confidence of {confidence:.2f}%.",
|
109 |
waveform_plot,
|
110 |
-
spectrogram_plot
|
111 |
-
transcription[0] # Assuming transcription returns a list with a single string
|
112 |
-
)
|
113 |
-
except Exception as e:
|
114 |
-
logger.error(f"Error during voice prediction: {e}")
|
115 |
-
return f"Error during processing: {e}", None, None, ""
|
116 |
|
117 |
with gr.Blocks(css="style.css") as demo:
|
118 |
gr.Markdown("## Voice Clone Detection")
|
119 |
gr.Markdown("Detects whether a voice is real or an AI-generated clone. Upload an audio file to see the results.")
|
120 |
-
|
121 |
with gr.Row():
|
122 |
audio_input = gr.Audio(label="Upload Audio File", type="filepath")
|
123 |
-
|
124 |
with gr.Row():
|
125 |
prediction_output = gr.Textbox(label="Prediction")
|
126 |
-
|
127 |
waveform_output = gr.Image(label="Waveform")
|
128 |
spectrogram_output = gr.Image(label="Spectrogram")
|
|
|
129 |
|
130 |
-
|
131 |
-
detect_button.click(
|
132 |
-
fn=predict_voice,
|
133 |
-
inputs=[audio_input],
|
134 |
-
outputs=[prediction_output, waveform_output, spectrogram_output, transcription_output]
|
135 |
-
)
|
136 |
-
|
137 |
-
# Launch the interface
|
138 |
-
demo.launch()
|
|
|
3 |
# --- Imports ---------------------------------------------------------------
# NOTE(review): `import gradio as gr` and `import librosa` are presumably on
# the first two lines of the file (outside this diff hunk) — confirm.
import numpy as np
import torch
import matplotlib.pyplot as plt
import librosa.display  # specshow lives in the display submodule; older
                        # librosa versions require this explicit import
from transformers import AutoModelForAudioClassification, ASTFeatureExtractor
import random
import tempfile
import logging
import os

# Write DEBUG-level logs to app.log, truncating the file on each start.
logging.basicConfig(level=logging.DEBUG, filename='app.log', filemode='w',
                    format='%(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Audio-classification model + AST feature extractor, loaded from the
# repository root ("./" — config and weights are expected beside this script).
model = AutoModelForAudioClassification.from_pretrained("./")
feature_extractor = ASTFeatureExtractor.from_pretrained("./")
|
18 |
def plot_waveform(waveform, sr):
    """Render *waveform* as a time-domain plot and save it to a PNG.

    Args:
        waveform: 1-D array of audio samples.
        sr: sample rate in Hz, used to scale the time axis.

    Returns:
        Path of the saved PNG file (created in the working directory so the
        UI can serve it).

    Raises:
        Exception: re-raised after logging if plotting or saving fails.
    """
    try:
        plt.figure(figsize=(24, 8))  # Doubled size for larger visuals
        plt.title('Waveform')
        plt.ylabel('Amplitude')
        plt.plot(np.linspace(0, len(waveform) / sr, len(waveform)), waveform)
        plt.xlabel('Time (s)')
        # delete=False so the caller can read the file afterwards. Close the
        # handle immediately — only the name is needed, and the open handle
        # leaked a file descriptor (and blocks savefig-by-path on Windows).
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
        temp_file.close()
        plt.savefig(temp_file.name)
        logger.debug(f"Waveform image generated: {temp_file.name}, Size: {os.path.getsize(temp_file.name)} bytes")
        return temp_file.name
    except Exception as e:
        # Restores the error logging the previous revision had.
        logger.error(f"Error generating waveform image: {e}")
        raise
    finally:
        plt.close()  # always release the figure, even on failure
29 |
|
30 |
def plot_spectrogram(waveform, sr):
    """Render a 128-band mel spectrogram of *waveform* and save it to a PNG.

    Args:
        waveform: 1-D array of audio samples.
        sr: sample rate in Hz.

    Returns:
        Path of the saved PNG file.

    Raises:
        Exception: re-raised after logging if plotting or saving fails.
    """
    try:
        S = librosa.feature.melspectrogram(y=waveform, sr=sr, n_mels=128)
        # Convert power to dB relative to the loudest bin for a readable plot.
        S_DB = librosa.power_to_db(S, ref=np.max)
        plt.figure(figsize=(24, 12))  # Doubled size for larger visuals
        librosa.display.specshow(S_DB, sr=sr, x_axis='time', y_axis='mel')
        plt.title('Mel Spectrogram')
        plt.colorbar(format='%+2.0f dB')
        plt.tight_layout()
        # Close the temp-file handle right away: only the name is needed, and
        # the previously open handle leaked a file descriptor.
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
        temp_file.close()
        plt.savefig(temp_file.name)
        logger.debug(f"Spectrogram image generated: {temp_file.name}, Size: {os.path.getsize(temp_file.name)} bytes")
        return temp_file.name
    except Exception as e:
        # Restores the error logging the previous revision had.
        logger.error(f"Error generating spectrogram image: {e}")
        raise
    finally:
        plt.close()  # always release the figure, even on failure
43 |
|
44 |
def custom_feature_extraction(audio, sr=16000, target_length=1024):
|
45 |
features = feature_extractor(audio, sampling_rate=sr, return_tensors="pt", padding="max_length", max_length=target_length)
|
|
|
49 |
def apply_time_shift(waveform, max_shift_fraction=0.1):
    """Circularly shift *waveform* by a random offset (data augmentation).

    The offset is drawn uniformly from +/- ``max_shift_fraction`` of the
    signal length; samples rolled off one end reappear at the other.
    """
    limit = int(max_shift_fraction * len(waveform))
    offset = random.randint(-limit, limit)
    return np.roll(waveform, offset)
51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
def predict_voice(audio_file_path):
    """Classify an audio file as a real human voice or an AI-generated clone.

    Args:
        audio_file_path: path of the uploaded audio file.

    Returns:
        A 3-tuple ``(message, waveform_png_path, spectrogram_png_path)``; on
        failure the message describes the error and both image paths are None.
    """
    try:
        # Resample everything to 16 kHz mono to match the feature extractor.
        waveform, _ = librosa.load(audio_file_path, sr=16000, mono=True)
        # Test-time augmentation: average logits over the original and a
        # randomly time-shifted copy to stabilise the prediction.
        augmented_waveform = apply_time_shift(waveform)
        original_features = custom_feature_extraction(waveform, sr=16000)
        augmented_features = custom_feature_extraction(augmented_waveform, sr=16000)
        with torch.no_grad():
            outputs_original = model(original_features)
            outputs_augmented = model(augmented_features)
        logits = (outputs_original.logits + outputs_augmented.logits) / 2
        predicted_index = logits.argmax()
        original_label = model.config.id2label[predicted_index.item()]
        confidence = torch.softmax(logits, dim=1).max().item() * 100
        # Map the model's raw labels to user-facing names.
        label_mapping = {"Spoof": "AI-generated Clone", "Bonafide": "Real Human Voice"}
        new_label = label_mapping.get(original_label, "Unknown")
        waveform_plot = plot_waveform(waveform, 16000)
        spectrogram_plot = plot_spectrogram(waveform, 16000)
        return (f"The voice is classified as '{new_label}' with a confidence of {confidence:.2f}%.",
                waveform_plot,
                spectrogram_plot)
    except Exception as e:
        # Restores the error handling the previous revision had, sized to the
        # current three outputs (prediction text + two image slots).
        logger.error(f"Error during voice prediction: {e}")
        return f"Error during processing: {e}", None, None
71 |
|
72 |
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks(css="style.css") as demo:
    gr.Markdown("## Voice Clone Detection")
    gr.Markdown("Detects whether a voice is real or an AI-generated clone. Upload an audio file to see the results.")
    with gr.Row():
        audio_input = gr.Audio(label="Upload Audio File", type="filepath")
        detect_button = gr.Button("Detect Voice Clone")
    with gr.Row():
        prediction_output = gr.Textbox(label="Prediction")
    with gr.Row():
        waveform_output = gr.Image(label="Waveform")
        spectrogram_output = gr.Image(label="Spectrogram")
    # Wire the button to the classifier; outputs map 1:1 onto the three
    # values predict_voice returns.
    detect_button.click(
        fn=predict_voice,
        inputs=[audio_input],
        outputs=[prediction_output, waveform_output, spectrogram_output],
    )

demo.launch()