Kabatubare committed on
Commit 9ec21ae
1 Parent(s): b8277b5

Update app.py

Files changed (1)
  1. app.py +45 -98
app.py CHANGED
@@ -3,61 +3,43 @@ import librosa
  import numpy as np
  import torch
  import matplotlib.pyplot as plt
- from transformers import AutoModelForAudioClassification, ASTFeatureExtractor, Wav2Vec2Processor, Wav2Vec2ForCTC
+ from transformers import AutoModelForAudioClassification, ASTFeatureExtractor
  import random
  import tempfile
  import logging
+ import os

  logging.basicConfig(level=logging.DEBUG, filename='app.log', filemode='w', format='%(name)s - %(levelname)s - %(message)s')
  logger = logging.getLogger(__name__)

- # Load Wav2Vec 2.0 models
- wav2vec_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
- wav2vec_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
-
- # Original model and feature extractor loading
  model = AutoModelForAudioClassification.from_pretrained("./")
  feature_extractor = ASTFeatureExtractor.from_pretrained("./")

  def plot_waveform(waveform, sr):
-     try:
-         plt.figure(figsize=(12, 4))
-         plt.title('Waveform')
-         plt.ylabel('Amplitude')
-         plt.plot(np.linspace(0, len(waveform) / sr, len(waveform)), waveform)
-         plt.xlabel('Time (s)')
-         temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
-         plt.savefig(temp_file.name)
-         plt.close()
-
-         file_size = os.path.getsize(temp_file.name)
-         logger.debug(f"Waveform image generated: {temp_file.name}, Size: {file_size} bytes")
-
-         return temp_file.name
-     except Exception as e:
-         logger.error(f"Error generating waveform image: {e}")
-         raise
+     plt.figure(figsize=(24, 8))  # Doubled size for larger visuals
+     plt.title('Waveform')
+     plt.ylabel('Amplitude')
+     plt.plot(np.linspace(0, len(waveform) / sr, len(waveform)), waveform)
+     plt.xlabel('Time (s)')
+     temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
+     plt.savefig(temp_file.name)
+     plt.close()
+     logger.debug(f"Waveform image generated: {temp_file.name}, Size: {os.path.getsize(temp_file.name)} bytes")
+     return temp_file.name

  def plot_spectrogram(waveform, sr):
-     try:
-         S = librosa.feature.melspectrogram(y=waveform, sr=sr, n_mels=128)
-         S_DB = librosa.power_to_db(S, ref=np.max)
-         plt.figure(figsize=(12, 6))
-         librosa.display.specshow(S_DB, sr=sr, x_axis='time', y_axis='mel')
-         plt.title('Mel Spectrogram')
-         plt.colorbar(format='%+2.0f dB')
-         plt.tight_layout()
-         temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
-         plt.savefig(temp_file.name)
-         plt.close()
-
-         file_size = os.path.getsize(temp_file.name)
-         logger.debug(f"Spectrogram image generated: {temp_file.name}, Size: {file_size} bytes")
-
-         return temp_file.name
-     except Exception as e:
-         logger.error(f"Error generating spectrogram image: {e}")
-         raise
+     S = librosa.feature.melspectrogram(y=waveform, sr=sr, n_mels=128)
+     S_DB = librosa.power_to_db(S, ref=np.max)
+     plt.figure(figsize=(24, 12))  # Doubled size for larger visuals
+     librosa.display.specshow(S_DB, sr=sr, x_axis='time', y_axis='mel')
+     plt.title('Mel Spectrogram')
+     plt.colorbar(format='%+2.0f dB')
+     plt.tight_layout()
+     temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
+     plt.savefig(temp_file.name)
+     plt.close()
+     logger.debug(f"Spectrogram image generated: {temp_file.name}, Size: {os.path.getsize(temp_file.name)} bytes")
+     return temp_file.name

  def custom_feature_extraction(audio, sr=16000, target_length=1024):
      features = feature_extractor(audio, sampling_rate=sr, return_tensors="pt", padding="max_length", max_length=target_length)
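Only the first line of custom_feature_extraction is visible in this hunk. That call hands the raw waveform to the AST feature extractor, which converts it to log-mel frames and pads or truncates to a fixed frame count so the classifier always sees the same input shape (app.py also passes padding="max_length" and max_length=target_length explicitly). A small illustrative sketch of that step in isolation, using the public MIT/ast-finetuned-audioset-10-10-0.4593 extractor as a stand-in for the checkpoint shipped in this repo:

import numpy as np
from transformers import ASTFeatureExtractor

# Public AST extractor as a stand-in for the repository-local one loaded with from_pretrained("./").
feature_extractor = ASTFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")

audio = np.random.randn(16000).astype(np.float32)  # one second of synthetic audio at 16 kHz
features = feature_extractor(audio, sampling_rate=16000, return_tensors="pt")

# The AST extractor pads or truncates to its configured max_length by default, so every clip
# comes out with the same frame count (typically 1024 frames x 128 mel bins).
print(features["input_values"].shape)

The second hunk below shows how these features are consumed by the classifier.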
@@ -67,72 +49,37 @@ def apply_time_shift(waveform, max_shift_fraction=0.1):
      shift = random.randint(-int(max_shift_fraction * len(waveform)), int(max_shift_fraction * len(waveform)))
      return np.roll(waveform, shift)

- def transcribe_audio(audio_file_path):
-     waveform, _ = librosa.load(audio_file_path, sr=wav2vec_processor.feature_extractor.sampling_rate, mono=True)
-     input_values = wav2vec_processor(waveform, return_tensors="pt", padding="longest").input_values
-     with torch.no_grad():
-         logits = wav2vec_model(input_values).logits
-     predicted_ids = torch.argmax(logits, dim=-1)
-     transcription = wav2vec_processor.batch_decode(predicted_ids)
-     return transcription
-
  def predict_voice(audio_file_path):
-     try:
-         transcription = transcribe_audio(audio_file_path)
-
-         waveform, sample_rate = librosa.load(audio_file_path, sr=feature_extractor.sampling_rate, mono=True)
-         augmented_waveform = apply_time_shift(waveform)
-
-         original_features = custom_feature_extraction(waveform, sr=sample_rate)
-         augmented_features = custom_feature_extraction(augmented_waveform, sr=sample_rate)
-
-         with torch.no_grad():
-             outputs_original = model(original_features)
-             outputs_augmented = model(augmented_features)
-
-         logits = (outputs_original.logits + outputs_augmented.logits) / 2
-         predicted_index = logits.argmax()
-         original_label = model.config.id2label[predicted_index.item()]
-         confidence = torch.softmax(logits, dim=1).max().item() * 100
-
-         label_mapping = {
-             "Spoof": "AI-generated Clone",
-             "Bonafide": "Real Human Voice"
-         }
-         new_label = label_mapping.get(original_label, "Unknown")
-
-         waveform_plot = plot_waveform(waveform, sample_rate)
-         spectrogram_plot = plot_spectrogram(waveform, sample_rate)
-
-         return (
-             f"The voice is classified as '{new_label}' with a confidence of {confidence:.2f}%.",
+     waveform, _ = librosa.load(audio_file_path, sr=16000, mono=True)  # Ensure all audio is resampled to 16kHz
+     augmented_waveform = apply_time_shift(waveform)
+     original_features = custom_feature_extraction(waveform, sr=16000)  # Adjusted sample rate to 16kHz
+     augmented_features = custom_feature_extraction(augmented_waveform, sr=16000)  # Adjusted sample rate to 16kHz
+     with torch.no_grad():
+         outputs_original = model(original_features)
+         outputs_augmented = model(augmented_features)
+     logits = (outputs_original.logits + outputs_augmented.logits) / 2
+     predicted_index = logits.argmax()
+     original_label = model.config.id2label[predicted_index.item()]
+     confidence = torch.softmax(logits, dim=1).max().item() * 100
+     label_mapping = {"Spoof": "AI-generated Clone", "Bonafide": "Real Human Voice"}
+     new_label = label_mapping.get(original_label, "Unknown")
+     waveform_plot = plot_waveform(waveform, 16000)  # Adjusted sample rate to 16kHz
+     spectrogram_plot = plot_spectrogram(waveform, 16000)  # Adjusted sample rate to 16kHz
+     return (f"The voice is classified as '{new_label}' with a confidence of {confidence:.2f}%.",
              waveform_plot,
-             spectrogram_plot,
-             transcription[0]  # Assuming transcription returns a list with a single string
-         )
-     except Exception as e:
-         logger.error(f"Error during voice prediction: {e}")
-         return f"Error during processing: {e}", None, None, ""
+             spectrogram_plot)

  with gr.Blocks(css="style.css") as demo:
      gr.Markdown("## Voice Clone Detection")
      gr.Markdown("Detects whether a voice is real or an AI-generated clone. Upload an audio file to see the results.")
-
      with gr.Row():
          audio_input = gr.Audio(label="Upload Audio File", type="filepath")
-
+         detect_button = gr.Button("Detect Voice Clone")
      with gr.Row():
          prediction_output = gr.Textbox(label="Prediction")
-         transcription_output = gr.Textbox(label="Transcription")  # Fixed indentation
+     with gr.Row():
          waveform_output = gr.Image(label="Waveform")
          spectrogram_output = gr.Image(label="Spectrogram")
+     detect_button.click(fn=predict_voice, inputs=[audio_input], outputs=[prediction_output, waveform_output, spectrogram_output])

-     detect_button = gr.Button("Detect Voice Clone")
-     detect_button.click(
-         fn=predict_voice,
-         inputs=[audio_input],
-         outputs=[prediction_output, waveform_output, spectrogram_output, transcription_output]
-     )
-
-     # Launch the interface
-     demo.launch()
+ demo.launch()
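The rewritten predict_voice drops the Wav2Vec 2.0 transcription path entirely and keeps a light test-time augmentation: the classifier is run on the original waveform and on a randomly time-shifted copy, and the two sets of logits are averaged before the softmax confidence is read off. A minimal, self-contained sketch of that pattern; the toy linear layer below is a placeholder for the fine-tuned AST model, not the checkpoint in this repo:

import random

import numpy as np
import torch

def apply_time_shift(waveform, max_shift_fraction=0.1):
    # Same augmentation as app.py: circularly shift the signal by a random offset.
    shift = random.randint(-int(max_shift_fraction * len(waveform)), int(max_shift_fraction * len(waveform)))
    return np.roll(waveform, shift)

toy_model = torch.nn.Linear(16000, 2)                 # placeholder classifier: (1, 16000) -> (1, 2) logits
waveform = np.random.randn(16000).astype(np.float32)  # one second of synthetic "audio" at 16 kHz
shifted = apply_time_shift(waveform)

with torch.no_grad():
    logits_original = toy_model(torch.from_numpy(waveform).unsqueeze(0))
    logits_shifted = toy_model(torch.from_numpy(shifted).unsqueeze(0))

# Average the two predictions, then read off the label index and confidence as app.py does.
logits = (logits_original + logits_shifted) / 2
predicted_index = logits.argmax().item()
confidence = torch.softmax(logits, dim=1).max().item() * 100
print(predicted_index, f"{confidence:.2f}%")

Averaging the two views does not change the model; it only smooths the decision for a single clip, which is why the label and confidence are computed from the mean logits.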
 
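After the change, a single click handler feeds predict_voice's three return values positionally into the three output components. A stripped-down sketch of the same wiring with a dummy handler (dummy_predict is a placeholder, not part of the repo), just to show the click-to-outputs mapping:

import gradio as gr

def dummy_predict(audio_file_path):
    # Placeholder for predict_voice: (text, waveform image path, spectrogram image path).
    return f"Received: {audio_file_path}", None, None

with gr.Blocks() as demo:
    with gr.Row():
        audio_input = gr.Audio(label="Upload Audio File", type="filepath")
        detect_button = gr.Button("Detect Voice Clone")
    with gr.Row():
        prediction_output = gr.Textbox(label="Prediction")
    with gr.Row():
        waveform_output = gr.Image(label="Waveform")
        spectrogram_output = gr.Image(label="Spectrogram")
    # The tuple returned by fn is matched positionally onto the outputs list.
    detect_button.click(fn=dummy_predict, inputs=[audio_input],
                        outputs=[prediction_output, waveform_output, spectrogram_output])

if __name__ == "__main__":
    demo.launch()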