DavidCombei commited on
Commit
24e3f01
·
verified ·
1 Parent(s): 01c10d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -90
app.py CHANGED
@@ -1,90 +1,89 @@
1
- import joblib
2
- from transformers import AutoFeatureExtractor, Wav2Vec2Model
3
- import torch
4
- import librosa
5
- import numpy as np
6
- from sklearn.linear_model import LogisticRegression
7
- import gradio as gr
8
- import os
9
-
10
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
-
12
- class CustomWav2Vec2Model(Wav2Vec2Model):
13
- def __init__(self, config):
14
- super().__init__(config)
15
- self.encoder.layers = self.encoder.layers[:9]
16
-
17
- truncated_model = CustomWav2Vec2Model.from_pretrained("facebook/wav2vec2-xls-r-2b")
18
-
19
- class HuggingFaceFeatureExtractor:
20
- def __init__(self, model, feature_extractor_name):
21
- self.device = device
22
- self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_name)
23
- self.model = model
24
- self.model.eval()
25
- self.model.to(self.device)
26
-
27
- def __call__(self, audio, sr):
28
- inputs = self.feature_extractor(
29
- audio,
30
- sampling_rate=sr,
31
- return_tensors="pt",
32
- padding=True,
33
- )
34
- inputs = {k: v.to(self.device) for k, v in inputs.items()}
35
- with torch.no_grad():
36
- outputs = self.model(**inputs, output_hidden_states=True)
37
- return outputs.hidden_states[9]
38
-
39
- FEATURE_EXTRACTOR = HuggingFaceFeatureExtractor(truncated_model, "facebook/wav2vec2-xls-r-2b")
40
- classifier,scaler, thresh = joblib.load('logreg_margin_pruning_ALL_with_scaler+threshold.joblib')
41
-
42
- def segment_audio(audio, sr, segment_duration):
43
- segment_samples = int(segment_duration * sr)
44
- total_samples = len(audio)
45
- segments = [audio[i:i + segment_samples] for i in range(0, total_samples, segment_samples)]
46
- return segments
47
-
48
- def process_audio(input_data, segment_duration=10):
49
- audio, sr = librosa.load(input_data, sr=16000)
50
- if len(audio.shape) > 1:
51
- audio = audio[0]
52
- segments = segment_audio(audio, sr, segment_duration)
53
- segment_predictions = []
54
- output_lines = []
55
- eer_threshold = thresh - 5e5 # small margin error due to feature extractor space differences
56
- for idx, segment in enumerate(segments):
57
- features = FEATURE_EXTRACTOR(segment, sr)
58
- features_avg = torch.mean(features, dim=1).cpu().numpy()
59
- features_avg = features_avg.reshape(1, -1)
60
- decision_score = classifier.decision_function(features_avg)
61
- decision_score_scaled = scaler.transform(decision_score.reshape(-1, 1)).flatten()
62
- if decision_score_scaled >= eer_threshold:
63
- pred = 1
64
- confidence_percentage = decision_score_scaled[0] * 100
65
- else:
66
- pred = 0
67
- confidence_percentage = (1 - decision_score_scaled[0]) * 100
68
- segment_predictions.append(pred)
69
- line = f"Segment {idx + 1}: {'Real' if pred == 1 else 'Fake'} (Confidence: {round(confidence_percentage, 2)}%)"
70
- output_lines.append(line)
71
- overall_prediction = 1 if sum(segment_predictions) > (len(segment_predictions) / 2) else 0
72
- overall_line = f"Overall Prediction: {'Real' if overall_prediction == 1 else 'Fake'}"
73
- output_str = overall_line + "\n" + "\n".join(output_lines)
74
- return output_str
75
-
76
- def gradio_interface(audio):
77
- if audio:
78
- return process_audio(audio)
79
- else:
80
- return "please upload an audio file"
81
-
82
- interface = gr.Interface(
83
- fn=gradio_interface,
84
- inputs=[gr.Audio(type="filepath", label="Upload Audio")],
85
- outputs="text",
86
- title="SOL2 Audio Deepfake Detection Demo",
87
- description="Upload an audio file to check if it's AI-generated",
88
- )
89
-
90
- interface.launch(share=True)
 
1
+ import joblib
2
+ from transformers import AutoFeatureExtractor, Wav2Vec2Model
3
+ import torch
4
+ import librosa
5
+ import numpy as np
6
+ from sklearn.linear_model import LogisticRegression
7
+ import gradio as gr
8
+ import os
9
+
10
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
+
12
+ class CustomWav2Vec2Model(Wav2Vec2Model):
13
+ def __init__(self, config):
14
+ super().__init__(config)
15
+ self.encoder.layers = self.encoder.layers[:9]
16
+
17
+ truncated_model = CustomWav2Vec2Model.from_pretrained("facebook/wav2vec2-xls-r-2b")
18
+
19
+ class HuggingFaceFeatureExtractor:
20
+ def __init__(self, model, feature_extractor_name):
21
+ self.device = device
22
+ self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_name)
23
+ self.model = model
24
+ self.model.eval()
25
+ self.model.to(self.device)
26
+
27
+ def __call__(self, audio, sr):
28
+ inputs = self.feature_extractor(
29
+ audio,
30
+ sampling_rate=sr,
31
+ return_tensors="pt",
32
+ padding=True,
33
+ )
34
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
35
+ with torch.no_grad():
36
+ outputs = self.model(**inputs, output_hidden_states=True)
37
+ return outputs.hidden_states[9]
38
+
39
+ FEATURE_EXTRACTOR = HuggingFaceFeatureExtractor(truncated_model, "facebook/wav2vec2-xls-r-2b")
40
+ classifier,scaler, thresh = joblib.load('logreg_margin_pruning_ALL_with_scaler+threshold.joblib')
41
+
42
+ def segment_audio(audio, sr, segment_duration):
43
+ segment_samples = int(segment_duration * sr)
44
+ total_samples = len(audio)
45
+ segments = [audio[i:i + segment_samples] for i in range(0, total_samples, segment_samples)]
46
+ return segments
47
+
48
+ def process_audio(input_data, segment_duration=10):
49
+ audio, sr = librosa.load(input_data, sr=16000)
50
+ if len(audio.shape) > 1:
51
+ audio = audio[0]
52
+ segments = segment_audio(audio, sr, segment_duration)
53
+ segment_predictions = []
54
+ output_lines = []
55
+ eer_threshold = thresh - 5e-3 # small margin error due to feature extractor space differences
56
+ for idx, segment in enumerate(segments):
57
+ features = FEATURE_EXTRACTOR(segment, sr)
58
+ features_avg = torch.mean(features, dim=1).cpu().numpy()
59
+ features_avg = features_avg.reshape(1, -1)
60
+ decision_score = classifier.decision_function(features_avg)
61
+ decision_score_scaled = scaler.transform(decision_score.reshape(-1, 1)).flatten()
62
+ pred = 1 if decision_value >= eer_threshold else 0
63
+ if pred == 1:
64
+ confidence_percentage = ((decision_score_scaled - eer_threshold) / (1 - eer_threshold)) * 100
65
+ else:
66
+ confidence_percentage = ((eer_threshold - decision_score_scaled) / eer_threshold) * 100
67
+ segment_predictions.append(pred)
68
+ line = f"Segment {idx + 1}: {'Real' if pred == 1 else 'Fake'} (Confidence: {round(confidence_percentage, 2)}%)"
69
+ output_lines.append(line)
70
+ overall_prediction = 1 if sum(segment_predictions) > (len(segment_predictions) / 2) else 0
71
+ overall_line = f"Overall Prediction: {'Real' if overall_prediction == 1 else 'Fake'}"
72
+ output_str = overall_line + "\n" + "\n".join(output_lines)
73
+ return output_str
74
+
75
+ def gradio_interface(audio):
76
+ if audio:
77
+ return process_audio(audio)
78
+ else:
79
+ return "please upload an audio file"
80
+
81
+ interface = gr.Interface(
82
+ fn=gradio_interface,
83
+ inputs=[gr.Audio(type="filepath", label="Upload Audio")],
84
+ outputs="text",
85
+ title="SOL2 Audio Deepfake Detection Demo",
86
+ description="Upload an audio file to check if it's AI-generated",
87
+ )
88
+
89
+ interface.launch(share=True)