import gradio as gr
import torch
import torchaudio
import numpy as np
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
import torch.nn.functional as F
import torchaudio.transforms as T

MODEL_ID = "Zeyadd-Mostaffa/Deepfake-Audio-Detection-v1"

# 1) Load model & feature extractor
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
model = AutoModelForAudioClassification.from_pretrained(MODEL_ID)
model.eval()

# Use the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

label_names = ["fake", "real"]  # Matches the model's label2id = {"fake": 0, "real": 1}
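# A more robust alternative (assumption: the checkpoint exposes its own id2label
# mapping in model.config) would be to read the names instead of hard-coding them:
#     label_names = [model.config.id2label[i] for i in range(model.config.num_labels)]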

def classify_audio(audio_file):
    """
    audio_file: path to the uploaded file (WAV, MP3, etc.)
    Returns: predicted label and confidence score
    """

    # 2) Load the audio file
    waveform, sr = torchaudio.load(audio_file)

    # If stereo, pick one channel or average
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    waveform = waveform.squeeze()  # (samples,)

    # 3) Resample if needed
    if sr != 16000:
        resampler = T.Resample(sr, 16000)
        waveform = resampler(waveform)
        sr = 16000

    # 4) Preprocess with the feature extractor
    inputs = feature_extractor(
        waveform.numpy(),
        sampling_rate=sr,
        return_tensors="pt",
        truncation=True,
        max_length=int(16000 * 6.0),  # 6 second max
    )
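
    # inputs["input_values"] is a float tensor of shape (1, num_samples), already
    # truncated to at most 6 seconds of 16 kHz audio by the max_length above.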

    # Move everything to device
    input_values = inputs["input_values"].to(device)

    with torch.no_grad():
        logits = model(input_values).logits

        # 5) Calculate probabilities using softmax
        probabilities = F.softmax(logits, dim=-1)
        
        # Get predicted label and confidence
        confidence, pred_id = torch.max(probabilities, dim=-1)
        predicted_label = label_names[pred_id.item()]

    # 6) Return label and confidence percentage
    return f"Prediction: {predicted_label}, Confidence: {confidence.item() * 100:.2f}%"

# 7) Build the Gradio interface
demo = gr.Interface(
    fn=classify_audio,
    inputs=gr.Audio(type="filepath"),
    outputs="text",
    title="Wav2Vec2 Deepfake Detection",
    description="Upload an audio sample to check whether it is fake or real, along with a confidence score."
)

if __name__ == "__main__":
    demo.launch()
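
# Quick local sanity check without the Gradio UI (illustrative only; "sample.wav" is
# a placeholder path, not shipped with this script):
#     print(classify_audio("sample.wav"))
#     # -> Prediction: <fake|real>, Confidence: <xx.xx>%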